
Add support for Vector<T>.Indices and Vector.CreateSequence #97880

Merged · 3 commits · Feb 5, 2024
5 changes: 5 additions & 0 deletions src/coreclr/jit/compiler.h
@@ -3159,6 +3159,9 @@ class Compiler
GenTree* gtNewSimdCreateScalarUnsafeNode(
var_types type, GenTree* op1, CorInfoType simdBaseJitType, unsigned simdSize);

GenTree* gtNewSimdCreateSequenceNode(
var_types type, GenTree* op1, GenTree* op2, CorInfoType simdBaseJitType, unsigned simdSize);

GenTree* gtNewSimdDotProdNode(var_types type,
GenTree* op1,
GenTree* op2,
@@ -3174,6 +3177,8 @@
CorInfoType simdBaseJitType,
unsigned simdSize);

GenTree* gtNewSimdGetIndicesNode(var_types type, CorInfoType simdBaseJitType, unsigned simdSize);

GenTree* gtNewSimdGetLowerNode(var_types type,
GenTree* op1,
CorInfoType simdBaseJitType,
282 changes: 282 additions & 0 deletions src/coreclr/jit/gentree.cpp
@@ -22630,6 +22630,197 @@ GenTree* Compiler::gtNewSimdCreateScalarUnsafeNode(var_types type,
return gtNewSimdHWIntrinsicNode(type, op1, hwIntrinsicID, simdBaseJitType, simdSize);
}

//----------------------------------------------------------------------------------------------
// Compiler::gtNewSimdCreateSequenceNode: Creates a new simd CreateSequence node
//
// Arguments:
// type - The return type of SIMD node being created
// op1 - The starting value
// op2 - The step value
// simdBaseJitType - The base JIT type of SIMD type of the intrinsic
// simdSize - The size of the SIMD type of the intrinsic
//
// Returns:
// The created CreateSequence node
//
GenTree* Compiler::gtNewSimdCreateSequenceNode(
var_types type, GenTree* op1, GenTree* op2, CorInfoType simdBaseJitType, unsigned simdSize)
{
// This effectively does: (Indices * op2) + Create(op1)
//
// When both op2 and op1 are constant we can fully fold this to a constant. Additionally,
// if only op2 is a constant we can simplify the computation by a lot. However, if only op1
// is constant then there isn't any real optimization we can do and we need the full computation.

assert(varTypeIsSIMD(type));
assert(getSIMDTypeForSize(simdSize) == type);

var_types simdBaseType = JitType2PreciseVarType(simdBaseJitType);
assert(varTypeIsArithmetic(simdBaseType));

GenTree* result = nullptr;
bool isPartial = true;

if (op2->OperIsConst())
{
GenTreeVecCon* vcon = gtNewVconNode(type);
uint32_t simdLength = getSIMDVectorLength(simdSize, simdBaseType);

switch (simdBaseType)
{
case TYP_BYTE:
case TYP_UBYTE:
{
uint8_t start = 0;

if (op1->OperIsConst())
{
assert(op1->IsIntegralConst());
start = static_cast<uint8_t>(op1->AsIntConCommon()->IntegralValue());
isPartial = false;
}

assert(op2->IsIntegralConst());
uint8_t step = static_cast<uint8_t>(op2->AsIntConCommon()->IntegralValue());

for (uint32_t index = 0; index < simdLength; index++)
{
vcon->gtSimdVal.u8[index] = static_cast<uint8_t>((index * step) + start);
}
break;
}

case TYP_SHORT:
case TYP_USHORT:
{
uint16_t start = 0;

if (op1->OperIsConst())
{
assert(op1->IsIntegralConst());
start = static_cast<uint16_t>(op1->AsIntConCommon()->IntegralValue());
isPartial = false;
}

assert(op2->IsIntegralConst());
uint16_t step = static_cast<uint16_t>(op2->AsIntConCommon()->IntegralValue());

for (uint32_t index = 0; index < simdLength; index++)
{
vcon->gtSimdVal.u16[index] = static_cast<uint16_t>((index * step) + start);
}
break;
}

case TYP_INT:
case TYP_UINT:
{
uint32_t start = 0;

if (op1->OperIsConst())
{
assert(op1->IsIntegralConst());
start = static_cast<uint32_t>(op1->AsIntConCommon()->IntegralValue());
isPartial = false;
}

assert(op2->IsIntegralConst());
uint32_t step = static_cast<uint32_t>(op2->AsIntConCommon()->IntegralValue());

for (uint32_t index = 0; index < simdLength; index++)
{
vcon->gtSimdVal.u32[index] = static_cast<uint32_t>((index * step) + start);
}
break;
}

case TYP_LONG:
case TYP_ULONG:
{
uint64_t start = 0;

if (op1->OperIsConst())
{
assert(op1->IsIntegralConst());
start = static_cast<uint64_t>(op1->AsIntConCommon()->IntegralValue());
isPartial = false;
}

assert(op2->IsIntegralConst());
uint64_t step = static_cast<uint64_t>(op2->AsIntConCommon()->IntegralValue());

for (uint32_t index = 0; index < simdLength; index++)
{
vcon->gtSimdVal.u64[index] = static_cast<uint64_t>((index * step) + start);
}
break;
}

case TYP_FLOAT:
{
float start = 0;

if (op1->OperIsConst())
{
assert(op1->IsCnsFltOrDbl());
start = static_cast<float>(op1->AsDblCon()->DconValue());
isPartial = false;
}

assert(op2->IsCnsFltOrDbl());
float step = static_cast<float>(op2->AsDblCon()->DconValue());

for (uint32_t index = 0; index < simdLength; index++)
{
vcon->gtSimdVal.f32[index] = static_cast<float>((index * step) + start);
}
break;
}

case TYP_DOUBLE:
{
double start = 0;

if (op1->OperIsConst())
{
assert(op1->IsCnsFltOrDbl());
start = static_cast<double>(op1->AsDblCon()->DconValue());
isPartial = false;
}

assert(op2->IsCnsFltOrDbl());
double step = static_cast<double>(op2->AsDblCon()->DconValue());

for (uint32_t index = 0; index < simdLength; index++)
{
vcon->gtSimdVal.f64[index] = static_cast<double>((index * step) + start);
}
break;
}

default:
{
unreached();
}
}

result = vcon;
}
else
{
GenTree* indices = gtNewSimdGetIndicesNode(type, simdBaseJitType, simdSize);
result = gtNewSimdBinOpNode(GT_MUL, type, indices, op2, simdBaseJitType, simdSize);
}

if (isPartial)
{
GenTree* start = gtNewSimdCreateBroadcastNode(type, op1, simdBaseJitType, simdSize);
result = gtNewSimdBinOpNode(GT_ADD, type, result, start, simdBaseJitType, simdSize);
}

return result;
}
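
For context, here is a minimal standalone sketch of the element-wise semantics that the constant-folding loops above compute: element i of the result is start + i * step. This is not part of the PR's diff; the concrete simdLength, start, and step values are illustrative only.

// --- Illustrative sketch, not part of this PR's diff ---
// Scalar model of the fully-constant CreateSequence fold:
// result = (Indices * step) + Create(start), element by element.
#include <cstdint>
#include <cstdio>

int main()
{
    const uint32_t simdLength = 8; // e.g. a 128-bit vector of 16-bit elements
    int16_t start = 3;
    int16_t step = 2;

    int16_t result[8];
    for (uint32_t index = 0; index < simdLength; index++)
    {
        // Same computation as the TYP_SHORT/TYP_USHORT folding loop above.
        result[index] = static_cast<int16_t>((index * step) + start);
    }

    for (uint32_t index = 0; index < simdLength; index++)
    {
        printf("%d ", result[index]); // prints: 3 5 7 9 11 13 15 17
    }
    printf("\n");
    return 0;
}
// --- End of sketch ---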

GenTree* Compiler::gtNewSimdDotProdNode(
var_types type, GenTree* op1, GenTree* op2, CorInfoType simdBaseJitType, unsigned simdSize)
{
@@ -22812,6 +23003,97 @@ GenTree* Compiler::gtNewSimdGetElementNode(
return gtNewSimdHWIntrinsicNode(type, op1, op2, intrinsicId, simdBaseJitType, simdSize);
}

//----------------------------------------------------------------------------------------------
// Compiler::gtNewSimdGetIndicesNode: Creates a new simd get_Indices node
//
// Arguments:
// type - The return type of SIMD node being created
// simdBaseJitType - The base JIT type of SIMD type of the intrinsic
// simdSize - The size of the SIMD type of the intrinsic
//
// Returns:
// The created get_Indices node
//
GenTree* Compiler::gtNewSimdGetIndicesNode(var_types type, CorInfoType simdBaseJitType, unsigned simdSize)
{
assert(varTypeIsSIMD(type));
assert(getSIMDTypeForSize(simdSize) == type);

var_types simdBaseType = JitType2PreciseVarType(simdBaseJitType);
assert(varTypeIsArithmetic(simdBaseType));

GenTreeVecCon* indices = gtNewVconNode(type);
uint32_t simdLength = getSIMDVectorLength(simdSize, simdBaseType);

switch (simdBaseType)
{
case TYP_BYTE:
case TYP_UBYTE:
{
for (uint32_t index = 0; index < simdLength; index++)
{
indices->gtSimdVal.u8[index] = static_cast<uint8_t>(index);
}
break;
}

case TYP_SHORT:
case TYP_USHORT:
{
for (uint32_t index = 0; index < simdLength; index++)
{
indices->gtSimdVal.u16[index] = static_cast<uint16_t>(index);
}
break;
}

case TYP_INT:
case TYP_UINT:
{
for (uint32_t index = 0; index < simdLength; index++)
{
indices->gtSimdVal.u32[index] = static_cast<uint32_t>(index);
}
break;
}

case TYP_LONG:
case TYP_ULONG:
{
for (uint32_t index = 0; index < simdLength; index++)
{
indices->gtSimdVal.u64[index] = static_cast<uint64_t>(index);
}
break;
}

case TYP_FLOAT:
{
for (uint32_t index = 0; index < simdLength; index++)
{
indices->gtSimdVal.f32[index] = static_cast<float>(index);
}
break;
}

case TYP_DOUBLE:
{
for (uint32_t index = 0; index < simdLength; index++)
{
indices->gtSimdVal.f64[index] = static_cast<double>(index);
}
break;
}

default:
{
unreached();
}
}

return indices;
}
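
For context, a minimal sketch of the constant vector that gtNewSimdGetIndicesNode builds: element i simply holds the value i. This is not part of the PR's diff; the simdSize and element type below are illustrative, and getSIMDVectorLength is paraphrased as simdSize divided by the element size.

// --- Illustrative sketch, not part of this PR's diff ---
#include <cstdint>
#include <cstdio>

int main()
{
    const unsigned simdSize = 16;                            // a 128-bit vector
    const unsigned simdLength = simdSize / sizeof(uint32_t); // 4 uint elements

    uint32_t indices[4];
    for (unsigned index = 0; index < simdLength; index++)
    {
        indices[index] = index; // { 0, 1, 2, 3 }, matching the TYP_INT/TYP_UINT case above
    }

    for (unsigned index = 0; index < simdLength; index++)
    {
        printf("%u ", indices[index]); // prints: 0 1 2 3
    }
    printf("\n");
    return 0;
}
// --- End of sketch ---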

GenTree* Compiler::gtNewSimdGetLowerNode(var_types type, GenTree* op1, CorInfoType simdBaseJitType, unsigned simdSize)
{
var_types simdBaseType = JitType2PreciseVarType(simdBaseJitType);
21 changes: 21 additions & 0 deletions src/coreclr/jit/hwintrinsicarm64.cpp
@@ -857,6 +857,27 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
break;
}

case NI_Vector64_CreateSequence:
case NI_Vector128_CreateSequence:
{
assert(sig->numArgs == 2);

if (varTypeIsLong(simdBaseType) && !impStackTop(0).val->OperIsConst())
{
// TODO-ARM64-CQ: We should support long/ulong multiplication.
break;
}

impSpillSideEffect(true, verCurrentState.esStackDepth -
2 DEBUGARG("Spilling op1 side effects for SimdAsHWIntrinsic"));

op2 = impPopStack().val;
op1 = impPopStack().val;

retNode = gtNewSimdCreateSequenceNode(retType, op1, op2, simdBaseJitType, simdSize);
break;
}

case NI_Vector64_CreateScalarUnsafe:
case NI_Vector128_CreateScalarUnsafe:
{
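
For context on the long/ulong bail-out in the importer case above: ARM64 NEON has no direct vector instruction for 64-bit element multiplication, so with a non-constant step the sequence would have to be built with per-element scalar work, roughly as in the sketch below. This is not part of the PR's diff; all names and values are illustrative.

// --- Illustrative sketch, not part of this PR's diff ---
// Scalar model of CreateSequence for a 128-bit vector of 64-bit elements with a
// run-time step, the case the importer currently declines to expand on ARM64.
#include <cstdint>
#include <cstdio>

int main()
{
    const unsigned simdLength = 2; // a 128-bit vector holds two 64-bit elements
    int64_t start = 100;
    int64_t step = 7;              // imagine this value is only known at run time

    int64_t sequence[2];
    for (unsigned index = 0; index < simdLength; index++)
    {
        sequence[index] = start + static_cast<int64_t>(index) * step; // 100, 107
    }

    printf("%lld %lld\n", (long long)sequence[0], (long long)sequence[1]);
    return 0;
}
// --- End of sketch ---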