Skip to content

Commit

Permalink
- Map the VarChar vector to its corresponding native type based on th…
Browse files Browse the repository at this point in the history
…e metadata
  • Loading branch information
SthuthiGhosh9400 authored and elbinpallimalilibm committed Oct 2, 2024
1 parent 3c67c3e commit d329e9f
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 85 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import com.facebook.presto.common.type.TimeType;
import com.facebook.presto.common.type.TimestampType;
import com.facebook.presto.common.type.TinyintType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.VarbinaryType;
import com.facebook.presto.common.type.VarcharType;
import com.facebook.presto.spi.ColumnHandle;
Expand Down Expand Up @@ -106,80 +107,89 @@ public Map<String, ColumnHandle> getColumnHandles(ConnectorSession session, Conn

String schemaValue = ((ArrowTableHandle) tableHandle).getSchema();
String tableValue = ((ArrowTableHandle) tableHandle).getTable();
String dataSourceSpecificSchemaValue = getDataSourceSpecificSchemaName(config, schemaValue);
String dataSourceSpecificTableName = getDataSourceSpecificTableName(config, tableValue);
List<Field> columnList = getColumnsList(dataSourceSpecificSchemaValue, dataSourceSpecificTableName, session);
String dbSpecificSchemaValue = getDataSourceSpecificSchemaName(config, schemaValue);
String dBSpecificTableName = getDataSourceSpecificTableName(config, tableValue);
List<Field> columnList = getColumnsList(dbSpecificSchemaValue, dBSpecificTableName, session);

for (Field field : columnList) {
String columnName = field.getName();
logger.debug("The value of the flight columnName is:- %s", columnName);

ArrowColumnHandle handle;
switch (field.getType().getTypeID()) {
case Int:
ArrowType.Int intType = (ArrowType.Int) field.getType();
switch (intType.getBitWidth()) {
case 64:
column.put(columnName, new ArrowColumnHandle(columnName, BigintType.BIGINT));
break;
case 32:
column.put(columnName, new ArrowColumnHandle(columnName, IntegerType.INTEGER));
break;
case 16:
column.put(columnName, new ArrowColumnHandle(columnName, SmallintType.SMALLINT));
break;
case 8:
column.put(columnName, new ArrowColumnHandle(columnName, TinyintType.TINYINT));
break;
default:
throw new ArrowException(ARROW_FLIGHT_ERROR, "Invalid bit width " + intType.getBitWidth());
}
handle = createArrowColumnHandleForIntType(columnName, intType);
break;
case Binary:
case LargeBinary:
case FixedSizeBinary:
column.put(columnName, new ArrowColumnHandle(columnName, VarbinaryType.VARBINARY));
handle = new ArrowColumnHandle(columnName, VarbinaryType.VARBINARY);
break;
case Date:
column.put(columnName, new ArrowColumnHandle(columnName, DateType.DATE));
handle = new ArrowColumnHandle(columnName, DateType.DATE);
break;
case Timestamp:
column.put(columnName, new ArrowColumnHandle(columnName, TimestampType.TIMESTAMP));
handle = new ArrowColumnHandle(columnName, TimestampType.TIMESTAMP);
break;
case Utf8:
case LargeUtf8:
column.put(columnName, new ArrowColumnHandle(columnName, VarcharType.VARCHAR));
handle = new ArrowColumnHandle(columnName, VarcharType.VARCHAR);
break;
case FloatingPoint:
ArrowType.FloatingPoint floatingPoint = (ArrowType.FloatingPoint) field.getType();
switch (floatingPoint.getPrecision()) {
case SINGLE:
column.put(columnName, new ArrowColumnHandle(columnName, RealType.REAL));
break;
case DOUBLE:
column.put(columnName, new ArrowColumnHandle(columnName, DoubleType.DOUBLE));
break;
default:
throw new ArrowException(ARROW_FLIGHT_ERROR, "Invalid floating point precision " + floatingPoint.getPrecision());
}
handle = createArrowColumnHandleForFloatingPointType(columnName, floatingPoint);
break;
case Decimal:
ArrowType.Decimal decimalType = (ArrowType.Decimal) field.getType();
int precision = decimalType.getPrecision();
int scale = decimalType.getScale();
column.put(columnName, new ArrowColumnHandle(columnName, DecimalType.createDecimalType(precision, scale)));
handle = new ArrowColumnHandle(columnName, DecimalType.createDecimalType(decimalType.getPrecision(), decimalType.getScale()));
break;
case Bool:
column.put(columnName, new ArrowColumnHandle(columnName, BooleanType.BOOLEAN));
handle = new ArrowColumnHandle(columnName, BooleanType.BOOLEAN);
break;
case Time:
column.put(columnName, new ArrowColumnHandle(columnName, TimeType.TIME));
handle = new ArrowColumnHandle(columnName, TimeType.TIME);
break;
default:
throw new UnsupportedOperationException("The data type " + field.getType().getTypeID() + " is not supported.");
}
Type type = overrideFieldType(field, handle.getColumnType());
if (!type.equals(handle.getColumnType())) {
handle = new ArrowColumnHandle(columnName, type);
}
column.put(columnName, handle);
}
return column;
}

/**
 * Builds the column handle for an Arrow integer column, choosing the Presto
 * integer type from the Arrow bit width.
 *
 * @param columnName name given to the created handle
 * @param intType Arrow integer type whose bit width selects the Presto type
 * @return a handle typed BIGINT (64), INTEGER (32), SMALLINT (16) or TINYINT (8)
 * @throws ArrowException if the bit width is none of 64, 32, 16 or 8
 */
private ArrowColumnHandle createArrowColumnHandleForIntType(String columnName, ArrowType.Int intType)
{
    int bitWidth = intType.getBitWidth();
    if (bitWidth == 64) {
        return new ArrowColumnHandle(columnName, BigintType.BIGINT);
    }
    if (bitWidth == 32) {
        return new ArrowColumnHandle(columnName, IntegerType.INTEGER);
    }
    if (bitWidth == 16) {
        return new ArrowColumnHandle(columnName, SmallintType.SMALLINT);
    }
    if (bitWidth == 8) {
        return new ArrowColumnHandle(columnName, TinyintType.TINYINT);
    }
    // Any other width has no mapping here.
    throw new ArrowException(ARROW_FLIGHT_ERROR, "Invalid bit width " + bitWidth);
}

/**
 * Builds the column handle for an Arrow floating-point column, choosing the
 * Presto type from the Arrow precision.
 *
 * @param columnName name given to the created handle
 * @param floatingPoint Arrow floating-point type whose precision selects the Presto type
 * @return a handle typed REAL for SINGLE precision or DOUBLE for DOUBLE precision
 * @throws ArrowException if the precision is neither SINGLE nor DOUBLE
 */
private ArrowColumnHandle createArrowColumnHandleForFloatingPointType(String columnName, ArrowType.FloatingPoint floatingPoint)
{
    final Type prestoType;
    switch (floatingPoint.getPrecision()) {
        case SINGLE:
            prestoType = RealType.REAL;
            break;
        case DOUBLE:
            prestoType = DoubleType.DOUBLE;
            break;
        default:
            throw new ArrowException(ARROW_FLIGHT_ERROR, "Invalid floating point precision " + floatingPoint.getPrecision());
    }
    return new ArrowColumnHandle(columnName, prestoType);
}

@Override
public List<ConnectorTableLayoutResult> getTableLayouts(ConnectorSession session, ConnectorTableHandle table, Constraint<ColumnHandle> constraint, Optional<Set<ColumnHandle>> desiredColumns)
{
Expand Down Expand Up @@ -209,73 +219,96 @@ public ConnectorTableMetadata getTableMetadata(ConnectorSession session, Connect

for (Field field : columnList) {
String columnName = field.getName();
switch (field.getType().getTypeID()) {
ArrowType type = field.getType();

ColumnMetadata columnMetadata;

switch (type.getTypeID()) {
case Int:
ArrowType.Int intType = (ArrowType.Int) field.getType();
switch (intType.getBitWidth()) {
case 64:
meta.add(new ColumnMetadata(columnName, BigintType.BIGINT));
break;
case 32:
meta.add(new ColumnMetadata(columnName, IntegerType.INTEGER));
break;
case 16:
meta.add(new ColumnMetadata(columnName, SmallintType.SMALLINT));
break;
case 8:
meta.add(new ColumnMetadata(columnName, TinyintType.TINYINT));
break;
default:
throw new ArrowException(ARROW_FLIGHT_ERROR, "Invalid bit width " + intType.getBitWidth());
}
ArrowType.Int intType = (ArrowType.Int) type;
columnMetadata = createIntColumnMetadata(columnName, intType);
break;
case Binary:
case LargeBinary:
case FixedSizeBinary:
meta.add(new ColumnMetadata(columnName, VarbinaryType.VARBINARY));
columnMetadata = new ColumnMetadata(columnName, VarbinaryType.VARBINARY);
break;
case Date:
meta.add(new ColumnMetadata(columnName, DateType.DATE));
columnMetadata = new ColumnMetadata(columnName, DateType.DATE);
break;
case Timestamp:
meta.add(new ColumnMetadata(columnName, TimestampType.TIMESTAMP));
columnMetadata = new ColumnMetadata(columnName, TimestampType.TIMESTAMP);
break;
case Utf8:
case LargeUtf8:
meta.add(new ColumnMetadata(columnName, VarcharType.VARCHAR));
columnMetadata = new ColumnMetadata(columnName, VarcharType.VARCHAR);
break;
case FloatingPoint:
ArrowType.FloatingPoint floatingPoint = (ArrowType.FloatingPoint) field.getType();
switch (floatingPoint.getPrecision()) {
case SINGLE:
meta.add(new ColumnMetadata(columnName, RealType.REAL));
break;
case DOUBLE:
meta.add(new ColumnMetadata(columnName, DoubleType.DOUBLE));
break;
default:
throw new ArrowException(ARROW_FLIGHT_ERROR, "Invalid floating point precision " + floatingPoint.getPrecision());
}
ArrowType.FloatingPoint floatingPointType = (ArrowType.FloatingPoint) type;
columnMetadata = createFloatingPointColumnMetadata(columnName, floatingPointType);
break;
case Decimal:
ArrowType.Decimal decimalType = (ArrowType.Decimal) field.getType();
int precision = decimalType.getPrecision();
int scale = decimalType.getScale();
meta.add(new ColumnMetadata(columnName, DecimalType.createDecimalType(precision, scale)));
ArrowType.Decimal decimalType = (ArrowType.Decimal) type;
columnMetadata = new ColumnMetadata(columnName, DecimalType.createDecimalType(decimalType.getPrecision(), decimalType.getScale()));
break;
case Time:
meta.add(new ColumnMetadata(columnName, TimeType.TIME));
columnMetadata = new ColumnMetadata(columnName, TimeType.TIME);
break;
case Bool:
meta.add(new ColumnMetadata(columnName, BooleanType.BOOLEAN));
columnMetadata = new ColumnMetadata(columnName, BooleanType.BOOLEAN);
break;
default:
throw new UnsupportedOperationException("The data type " + field.getType().getTypeID() + " is not supported.");
throw new UnsupportedOperationException("The data type " + type.getTypeID() + " is not supported.");
}

Type fieldType = overrideFieldType(field, columnMetadata.getType());
if (!fieldType.equals(columnMetadata.getType())) {
columnMetadata = new ColumnMetadata(columnName, fieldType);
}
meta.add(columnMetadata);
}
return new ConnectorTableMetadata(new SchemaTableName(((ArrowTableHandle) table).getSchema(), ((ArrowTableHandle) table).getTable()), meta);
}

/**
 * Builds the column metadata for an Arrow integer column, choosing the Presto
 * integer type from the Arrow bit width.
 *
 * @param columnName name given to the created metadata entry
 * @param intType Arrow integer type whose bit width selects the Presto type
 * @return metadata typed BIGINT (64), INTEGER (32), SMALLINT (16) or TINYINT (8)
 * @throws ArrowException if the bit width is none of 64, 32, 16 or 8
 */
private ColumnMetadata createIntColumnMetadata(String columnName, ArrowType.Int intType)
{
    final Type prestoType;
    switch (intType.getBitWidth()) {
        case 64:
            prestoType = BigintType.BIGINT;
            break;
        case 32:
            prestoType = IntegerType.INTEGER;
            break;
        case 16:
            prestoType = SmallintType.SMALLINT;
            break;
        case 8:
            prestoType = TinyintType.TINYINT;
            break;
        default:
            throw new ArrowException(ARROW_FLIGHT_ERROR, "Invalid bit width " + intType.getBitWidth());
    }
    return new ColumnMetadata(columnName, prestoType);
}

/**
 * Builds the column metadata for an Arrow floating-point column, choosing the
 * Presto type from the Arrow precision.
 *
 * @param columnName name given to the created metadata entry
 * @param floatingPointType Arrow floating-point type whose precision selects the Presto type
 * @return metadata typed REAL for SINGLE precision or DOUBLE for DOUBLE precision
 * @throws ArrowException if the precision is neither SINGLE nor DOUBLE
 */
private ColumnMetadata createFloatingPointColumnMetadata(String columnName, ArrowType.FloatingPoint floatingPointType)
{
    final Type prestoType;
    switch (floatingPointType.getPrecision()) {
        case SINGLE:
            prestoType = RealType.REAL;
            break;
        case DOUBLE:
            prestoType = DoubleType.DOUBLE;
            break;
        default:
            throw new ArrowException(ARROW_FLIGHT_ERROR, "Invalid floating point precision " + floatingPointType.getPrecision());
    }
    return new ColumnMetadata(columnName, prestoType);
}

/**
 * Extension hook that lets a subclass substitute its own Presto type for a column.
 *
 * <p>This base implementation returns {@code type} unchanged. Within this class the
 * hook is invoked once per field while building column handles and column metadata;
 * when the returned type is not equal to the default-mapped type, the handle or
 * metadata entry is rebuilt with the returned type.
 *
 * @param field the Arrow field describing the column being mapped
 * @param type the Presto type chosen by the default Arrow-to-Presto mapping
 * @return the type to use for the column; {@code type} itself unless overridden
 */
protected Type overrideFieldType(Field field, Type type)
{
return type;
}

@Override
public ColumnMetadata getColumnMetadata(ConnectorSession session, ConnectorTableHandle tableHandle, ColumnHandle columnHandle)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.facebook.presto.common.Page;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.type.CharType;
import com.facebook.presto.common.type.DateType;
import com.facebook.presto.common.type.DecimalType;
import com.facebook.presto.common.type.Decimals;
Expand All @@ -26,6 +27,7 @@
import com.facebook.presto.common.type.VarcharType;
import com.facebook.presto.spi.ConnectorPageSource;
import com.facebook.presto.spi.ConnectorSession;
import com.google.common.base.CharMatcher;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import org.apache.arrow.flight.FlightRuntimeException;
Expand Down Expand Up @@ -55,6 +57,8 @@

import java.math.BigDecimal;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.time.LocalTime;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
Expand All @@ -68,7 +72,6 @@ public class ArrowPageSource
private static final Logger logger = Logger.get(ArrowPageSource.class);
private final ArrowSplit split;
private final List<ArrowColumnHandle> columnHandles;
private final ArrowFlightClientHandler clientHandler;
private boolean completed;
private int currentPosition;
private Optional<VectorSchemaRoot> vectorSchemaRoot = Optional.empty();
Expand All @@ -80,7 +83,6 @@ public ArrowPageSource(ArrowSplit split, List<ArrowColumnHandle> columnHandles,
{
this.columnHandles = columnHandles;
this.split = split;
this.clientHandler = clientHandler;
getFlightStream(clientHandler, split.getTicket(), connectorSession);
}

Expand Down Expand Up @@ -151,9 +153,7 @@ public Page getNextPage()
@Override
public void close()
{
if (vectorSchemaRoot.isPresent()) {
vectorSchemaRoot.get().close();
}
vectorSchemaRoot.ifPresent(VectorSchemaRoot::close);
if (flightStream != null) {
try {
flightStream.close();
Expand All @@ -178,7 +178,7 @@ private void getFlightStream(ArrowFlightClientHandler clientHandler, byte[] tick
try {
Optional<String> uri = (split == null || split.getLocationUrls().isEmpty()) ?
Optional.empty() : Optional.of(split.getLocationUrls().get(0));
flightClient = clientHandler.getClient(uri);
ArrowFlightClient flightClient = clientHandler.getClient(uri);
flightStream = flightClient.getFlightClient().getStream(new Ticket(ticket), clientHandler.getCallOptions(connectorSession));
}
catch (FlightRuntimeException e) {
Expand Down Expand Up @@ -222,7 +222,15 @@ else if (vector instanceof Float8Vector) {
return buildBlockFromFloat8Vector((Float8Vector) vector, type);
}
else if (vector instanceof VarCharVector) {
return buildBlockFromVarCharVector((VarCharVector) vector, type);
if (type instanceof CharType) {
return buildBlockCharType((VarCharVector) vector, type);
}
else if (type instanceof TimeType) {
return buildBlockTimeType((VarCharVector) vector, type);
}
else {
return buildBlockFromVarCharVector((VarCharVector) vector, type);
}
}
else if (vector instanceof VarBinaryVector) {
return buildBlockFromVarBinaryVector((VarBinaryVector) vector, type);
Expand Down Expand Up @@ -572,4 +580,36 @@ private Block buildBlockFromTimeStampSecVector(TimeStampSecVector vector, Type t
}
return builder.build();
}

/**
 * Converts a VarChar vector into a block of the given type, stripping trailing
 * space characters from every non-null value.
 *
 * @param vector source vector; each non-null entry is decoded as UTF-8
 * @param type type used to create the block builder and to write each slice
 * @return a block with one position per vector value, nulls preserved
 */
private Block buildBlockCharType(VarCharVector vector, Type type)
{
    int valueCount = vector.getValueCount();
    BlockBuilder blockBuilder = type.createBlockBuilder(null, valueCount);
    for (int position = 0; position < valueCount; position++) {
        if (!vector.isNull(position)) {
            String decoded = new String(vector.get(position), StandardCharsets.UTF_8);
            // Only trailing ' ' characters are removed; other whitespace is kept.
            String withoutPadding = CharMatcher.is(' ').trimTrailingFrom(decoded);
            type.writeSlice(blockBuilder, Slices.utf8Slice(withoutPadding));
        }
        else {
            blockBuilder.appendNull();
        }
    }
    return blockBuilder.build();
}

/**
 * Converts a VarChar vector holding textual times into a block of the given type,
 * writing each non-null value as the number of milliseconds since midnight.
 *
 * @param vector source vector; each non-null entry is decoded as UTF-8 and parsed
 *        with {@link LocalTime#parse(CharSequence)}
 * @param type type used to create the block builder and to write each long value
 * @return a block with one position per vector value, nulls preserved
 */
private Block buildBlockTimeType(VarCharVector vector, Type type)
{
    int valueCount = vector.getValueCount();
    BlockBuilder blockBuilder = type.createBlockBuilder(null, valueCount);
    for (int position = 0; position < valueCount; position++) {
        if (!vector.isNull(position)) {
            LocalTime parsed = LocalTime.parse(new String(vector.get(position), StandardCharsets.UTF_8));
            // Elapsed time from 00:00 to the parsed time, truncated to whole milliseconds.
            Duration sinceMidnight = Duration.between(LocalTime.MIN, parsed);
            type.writeLong(blockBuilder, sinceMidnight.toMillis());
        }
        else {
            blockBuilder.appendNull();
        }
    }
    return blockBuilder.build();
}
}

0 comments on commit d329e9f

Please sign in to comment.