diff --git a/lib/ExecutionEngine/MemRefUtils.cpp b/lib/ExecutionEngine/MemRefUtils.cpp index e34bf4455ab9..04b86216dbde 100644 --- a/lib/ExecutionEngine/MemRefUtils.cpp +++ b/lib/ExecutionEngine/MemRefUtils.cpp @@ -27,6 +27,7 @@ #include "llvm/Support/Error.h" #include +#include using namespace mlir; @@ -45,12 +46,14 @@ allocMemRefDescriptor(Type type, bool allocateData = true, return make_string_error("memref with dynamic shapes not supported"); auto elementType = memRefType.getElementType(); - if (!elementType.isF32()) + VectorType vectorType = elementType.dyn_cast(); + if (!elementType.isF32() && + !(vectorType && vectorType.getElementType().isF32())) return make_string_error( - "memref with element other than f32 not supported"); + "memref with element other than f32 or vector of f32 not supported"); auto *descriptor = - reinterpret_cast(malloc(sizeof(StaticFloatMemRef))); + static_cast(malloc(sizeof(StaticFloatMemRef))); if (!allocateData) { descriptor->data = nullptr; return descriptor; @@ -59,7 +62,18 @@ allocMemRefDescriptor(Type type, bool allocateData = true, auto shape = memRefType.getShape(); int64_t size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); - descriptor->data = reinterpret_cast(malloc(sizeof(float) * size)); + // Align vector of f32 to the vector size boundary (to the closest greater + // power of two if the former isn't a power of two). + if (vectorType) { + int64_t numElements = vectorType.getNumElements(); + size *= numElements; + size_t alignment = llvm::PowerOf2Ceil(numElements * sizeof(float)); + posix_memalign(reinterpret_cast(&descriptor->data), alignment, + size * sizeof(float)); + } else { + descriptor->data = static_cast(malloc(sizeof(float) * size)); + } + for (int64_t i = 0; i < size; ++i) { descriptor->data[i] = initialValue; } diff --git a/lib/Support/JitRunner.cpp b/lib/Support/JitRunner.cpp index f87664d621a6..33ef3eb844bf 100644 --- a/lib/Support/JitRunner.cpp +++ b/lib/Support/JitRunner.cpp @@ -141,6 +141,10 @@ static void printOneMemRef(Type t, void *val) { auto shape = memRefType.getShape(); int64_t size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + if (auto vectorType = memRefType.getElementType().dyn_cast()) { + size *= vectorType.getNumElements(); + } + for (int64_t i = 0; i < size; ++i) { llvm::outs() << reinterpret_cast(val)->data[i] << ' '; } diff --git a/test/mlir-cpu-runner/simple.mlir b/test/mlir-cpu-runner/simple.mlir index afcf78eb80a2..2cbea8f9262b 100644 --- a/test/mlir-cpu-runner/simple.mlir +++ b/test/mlir-cpu-runner/simple.mlir @@ -2,6 +2,8 @@ // RUN: mlir-cpu-runner -e foo -init-value 1000 %s | FileCheck -check-prefix=NOMAIN %s // RUN: mlir-cpu-runner %s -O3 | FileCheck %s // RUN: mlir-cpu-runner -e affine -init-value 2.0 %s | FileCheck -check-prefix=AFFINE %s +// RUN: mlir-cpu-runner -e bar -init-value 2.0 %s | FileCheck -check-prefix=BAR %s +// RUN: mlir-cpu-runner -e large_vec_memref -init-value 2.0 %s | FileCheck -check-prefix=LARGE-VEC %s // RUN: cp %s %t // RUN: mlir-cpu-runner %t -dump-object-file | FileCheck %t @@ -49,3 +51,27 @@ func @affine(%a : memref<32xf32>) -> memref<32xf32> { return %a : memref<32xf32> } // AFFINE: 4.2{{0+}}e+01 + +func @bar(%a : memref<16xvector<4xf32>>) -> memref<16xvector<4xf32>> { + %c0 = constant 0 : index + %c1 = constant 1 : index + + %u = load %a[%c0] : memref<16xvector<4xf32>> + %v = load %a[%c1] : memref<16xvector<4xf32>> + %w = addf %u, %v : vector<4xf32> + store %w, %a[%c0] : memref<16xvector<4xf32>> + + return %a : memref<16xvector<4xf32>> +} +// BAR: 4.{{0+}}e+00 4.{{0+}}e+00 4.{{0+}}e+00 4.{{0+}}e+00 2.{{0+}}e+00 +// BAR-NEXT: 4.{{0+}}e+00 4.{{0+}}e+00 4.{{0+}}e+00 4.{{0+}}e+00 2.{{0+}}e+00 + +func @large_vec_memref(%arg2: memref<128x128xvector<8xf32>>) -> memref<128x128xvector<8xf32>> { + %c0 = constant 0 : index + %c127 = constant 127 : index + %v = constant dense<42.0> : vector<8xf32> + store %v, %arg2[%c0, %c0] : memref<128x128xvector<8xf32>> + store %v, %arg2[%c127, %c127] : memref<128x128xvector<8xf32>> + return %arg2 : memref<128x128xvector<8xf32>> +} +// LARGE-VEC: 4.200000e+01