Skip to content

[MLIR] [OpenMP] Initial support for OMP ALLOCATE directive op. #147900

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,31 @@
include "mlir/Dialect/OpenMP/OpenMPOpBase.td"
include "mlir/IR/SymbolInterfaces.td"

//===----------------------------------------------------------------------===//
// V5.2: [6.3] `align` clause
//===----------------------------------------------------------------------===//

class OpenMP_AlignClauseSkip<
bit traits = false, bit arguments = false, bit assemblyFormat = false,
bit description = false, bit extraClassDeclaration = false
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let arguments = (ins
ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$align
);

let optAssemblyFormat = [{
`align` `(` $align `)`
}];

let description = [{
The `align` clause is used to specify the byte alignment to use for
allocations associated with the construct on which the clause appears.
}];
}

def OpenMP_AlignClause : OpenMP_AlignClauseSkip<>;

//===----------------------------------------------------------------------===//
// V5.2: [5.11] `aligned` clause
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -84,6 +109,32 @@ class OpenMP_AllocateClauseSkip<

def OpenMP_AllocateClause : OpenMP_AllocateClauseSkip<>;

//===----------------------------------------------------------------------===//
// V5.2: [6.4] `allocator` clause
//===----------------------------------------------------------------------===//

class OpenMP_AllocatorClauseSkip<
bit traits = false, bit arguments = false, bit assemblyFormat = false,
bit description = false, bit extraClassDeclaration = false
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {

let arguments = (ins
OptionalAttr<AllocatorHandleAttr>:$allocator
);

let optAssemblyFormat = [{
`allocator` `(` custom<ClauseAttr>($allocator) `)`
}];

let description = [{
`allocator` specifies the memory allocator to be used for allocations
associated with the construct on which the clause appears.
}];
}

def OpenMP_AllocatorClause : OpenMP_AllocatorClauseSkip<>;

//===----------------------------------------------------------------------===//
// LLVM OpenMP extension `ompx_bare` clause
//===----------------------------------------------------------------------===//
Expand Down
30 changes: 30 additions & 0 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -263,4 +263,34 @@ def VariableCaptureKindAttr : OpenMP_EnumAttr<VariableCaptureKind,
let assemblyFormat = "`(` $value `)`";
}


//===----------------------------------------------------------------------===//
// allocator_handle enum.
//===----------------------------------------------------------------------===//

def OpenMP_AllocatorHandleNullAllocator : I32EnumAttrCase<"omp_null_allocator", 0>;
def OpenMP_AllocatorHandleDefaultMemAlloc : I32EnumAttrCase<"omp_default_mem_alloc", 1>;
def OpenMP_AllocatorHandleLargeCapMemAlloc : I32EnumAttrCase<"omp_large_cap_mem_alloc", 2>;
def OpenMP_AllocatorHandleConstMemAlloc : I32EnumAttrCase<"omp_const_mem_alloc", 3>;
def OpenMP_AllocatorHandleHighBwMemAlloc : I32EnumAttrCase<"omp_high_bw_mem_alloc", 4>;
def OpenMP_AllocatorHandleLowLatMemAlloc : I32EnumAttrCase<"omp_low_lat_mem_alloc", 5>;
def OpenMP_AllocatorHandleCgroupMemAlloc : I32EnumAttrCase<"omp_cgroup_mem_alloc", 6>;
def OpenMP_AllocatorHandlePteamMemAlloc : I32EnumAttrCase<"omp_pteam_mem_alloc", 7>;
def OpenMP_AllocatorHandlethreadMemAlloc : I32EnumAttrCase<"omp_thread_mem_alloc", 8>;

def AllocatorHandle : OpenMP_I32EnumAttr<
"AllocatorHandle",
"OpenMP allocator_handle", [
OpenMP_AllocatorHandleNullAllocator,
OpenMP_AllocatorHandleDefaultMemAlloc,
OpenMP_AllocatorHandleLargeCapMemAlloc,
OpenMP_AllocatorHandleConstMemAlloc,
OpenMP_AllocatorHandleHighBwMemAlloc,
OpenMP_AllocatorHandleLowLatMemAlloc,
OpenMP_AllocatorHandleCgroupMemAlloc,
OpenMP_AllocatorHandlePteamMemAlloc,
OpenMP_AllocatorHandlethreadMemAlloc
]>;

def AllocatorHandleAttr : OpenMP_EnumAttr<AllocatorHandle, "allocator_handle">;
#endif // OPENMP_ENUMS
23 changes: 23 additions & 0 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2090,4 +2090,27 @@ def MaskedOp : OpenMP_Op<"masked", clauses = [
];
}

//===----------------------------------------------------------------------===//
// [Spec 5.2] 6.5 allocate Directive
//===----------------------------------------------------------------------===//
def AllocateDirOp : OpenMP_Op<"allocate_dir", clauses = [
OpenMP_AlignClause, OpenMP_AllocatorClause
]> {
let summary = "allocate directive";
let description = [{
The storage for each list item that appears in the allocate directive is
provided an allocation through the memory allocator.
}] # clausesDescription;

let arguments = !con((ins Variadic<AnyType>:$varList),
clausesArgs);

// Override inherited assembly format to include `varList`.
let assemblyFormat = " `(` $varList `:` type($varList) `)` oilist(" #
clausesOptAssemblyFormat #
") attr-dict ";

let hasVerifier = 1;
}

#endif // OPENMP_OPS
15 changes: 15 additions & 0 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/bit.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
#include <cstddef>
Expand Down Expand Up @@ -3863,6 +3864,20 @@ LogicalResult ScanOp::verify() {
"reduction modifier");
}

/// Verifies align clause in allocate directive

LogicalResult AllocateDirOp::verify() {
std::optional<u_int64_t> align = this->getAlign();

if (align.has_value()) {
if ((align.value() > 0) && !llvm::has_single_bit(align.value()))
return emitError() << "ALIGN value : " << align.value()
<< " must be power of 2";
}

return success();
}

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"

Expand Down
24 changes: 24 additions & 0 deletions mlir/test/Dialect/OpenMP/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2993,3 +2993,27 @@ llvm.func @invalid_mapper(%0 : !llvm.ptr) {
}
llvm.return
}

// -----
func.func @invalid_allocate_align_1(%arg0 : memref<i32>) -> () {
// expected-error @below {{failed to satisfy constraint: 64-bit signless integer attribute whose value is positive}}
omp.allocate_dir (%arg0 : memref<i32>) align(-1)

return
}

// -----
func.func @invalid_allocate_align_2(%arg0 : memref<i32>) -> () {
// expected-error @below {{must be power of 2}}
omp.allocate_dir (%arg0 : memref<i32>) align(3)

return
}

// -----
func.func @invalid_allocate_allocator(%arg0 : memref<i32>) -> () {
// expected-error @below {{invalid clause value}}
omp.allocate_dir (%arg0 : memref<i32>) allocator(omp_small_cap_mem_alloc)

return
}
33 changes: 33 additions & 0 deletions mlir/test/Dialect/OpenMP/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3197,3 +3197,36 @@ func.func @omp_workshare_loop_wrapper_attrs(%idx : index) {
}
return
}

// CHECK-LABEL: func.func @omp_allocate_dir(
// CHECK-SAME: %[[ARG0:.*]]: memref<i32>,
// CHECK-SAME: %[[ARG1:.*]]: memref<i32>) {
func.func @omp_allocate_dir(%arg0 : memref<i32>, %arg1 : memref<i32>) -> () {

// Test with one data var
// CHECK: omp.allocate_dir(%[[ARG0]] : memref<i32>)
omp.allocate_dir (%arg0 : memref<i32>)

// Test with two data vars
// CHECK: omp.allocate_dir(%[[ARG0]], %[[ARG1]] : memref<i32>, memref<i32>)
omp.allocate_dir (%arg0, %arg1: memref<i32>, memref<i32>)

// Test with one data var and align clause
// CHECK: omp.allocate_dir(%[[ARG0]] : memref<i32>) align(2)
omp.allocate_dir (%arg0 : memref<i32>) align(2)

// Test with one data var and allocator clause
// CHECK: omp.allocate_dir(%[[ARG0]] : memref<i32>) allocator(omp_pteam_mem_alloc)
omp.allocate_dir (%arg0 : memref<i32>) allocator(omp_pteam_mem_alloc)

// Test with one data var, align clause and allocator clause
// CHECK: omp.allocate_dir(%[[ARG0]] : memref<i32>) align(2) allocator(omp_thread_mem_alloc)
omp.allocate_dir (%arg0 : memref<i32>) align(2) allocator(omp_thread_mem_alloc)

// Test with two data vars, align clause and allocator clause
// CHECK: omp.allocate_dir(%[[ARG0]], %[[ARG1]] : memref<i32>, memref<i32>) align(2) allocator(omp_cgroup_mem_alloc)
omp.allocate_dir (%arg0, %arg1 : memref<i32>, memref<i32>) align(2) allocator(omp_cgroup_mem_alloc)

return
}

Loading