Skip to content

[MLIR][NVVM] Support stmatrix intrinsics #148377

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 29 additions & 18 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1990,32 +1990,43 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">,
let hasVerifier = 1;
}

def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">,
Arguments<(ins LLVM_PointerShared:$ptr,
Variadic<I32>:$sources,
MMALayoutAttr:$layout)> {
// Attribute carrying the matrix shape used by ldmatrix / stmatrix as a pair of
// integers (m, n). With the struct(params) format below it is spelled with
// named fields, e.g. `#nvvm.ld_st_matrix_shape<m = 8, n = 8>`.
def LdStMatrixShapeAttr : NVVM_Attr<"LdStMatrixShape", "ld_st_matrix_shape"> {
let summary = "Matrix shape for ldmatrix and stmatrix";
let parameters = (ins "int":$m, "int":$n);
let assemblyFormat = "`<` struct(params) `>`";
}

// Enum cases for the ldmatrix / stmatrix element type. Each case gives the C++
// enumerator name, its integer value, and the keyword used in the textual
// assembly (e.g. `b16`).
def LdStMatrixEltTypeB16 : I32EnumAttrCase<"B16", 0, "b16">;
def LdStMatrixEltTypeB8 : I32EnumAttrCase<"B8", 1, "b8">;
def LdStMatrixEltTypeB8X16_B6X16_P32 : I32EnumAttrCase<"B8X16_B6X16_P32", 2, "b8x16.b6x16_p32">;
def LdStMatrixEltTypeB8X16_B4X16_P64 : I32EnumAttrCase<"B8X16_B4X16_P64", 3, "b8x16.b4x16_p64">;

// The enum itself, generated into the ::mlir::NVVM namespace.
// NOTE(review): genSpecializedAttr = 0 presumably suppresses the auto-generated
// attribute class so that LdStMatrixEltTypeAttr below is the only attribute
// wrapping this enum -- confirm against the MLIR EnumAttr documentation.
def LdStMatrixEltType : I32EnumAttr<"LdStMatrixEltType", "Element type for ldmatrix and stmatrix",
[LdStMatrixEltTypeB16, LdStMatrixEltTypeB8,
LdStMatrixEltTypeB8X16_B6X16_P32, LdStMatrixEltTypeB8X16_B4X16_P64]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::NVVM";
}
// Dialect attribute wrapping LdStMatrixEltType; spelled
// `#nvvm.ld_st_matrix_elttype<b16>` in the textual IR.
def LdStMatrixEltTypeAttr : EnumAttr<NVVM_Dialect, LdStMatrixEltType, "ld_st_matrix_elttype"> {
let assemblyFormat = "`<` $value `>`";
}

def NVVM_StMatrixOp: NVVM_Op<"stmatrix">,
Arguments<(ins LLVM_AnyPointer: $ptr, Variadic<I32>:$sources, MMALayoutAttr:$layout,
LdStMatrixShapeAttr:$shape, LdStMatrixEltTypeAttr:$elttype)> {
let summary = "cooperative matrix store";
let description = [{
Collectively store one or more matrices across all threads in a warp to the
location indicated by the address operand $ptr in shared memory.

[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix)
}];

let assemblyFormat = "$ptr `,` $sources attr-dict `:` type(operands)";
let extraClassDefinition = [{
std::string $cppClass::getPtx() {
int d = getSources().size();
std::string ptx = "stmatrix.sync.aligned";
ptx += ".x" + std::to_string(d);
if (getLayout() == NVVM::MMALayout::col)
ptx += ".trans";
if(d == 1) ptx += ".m8n8.shared.b16 [%0], {%1};";
if(d == 2) ptx += ".m8n8.shared.b16 [%0], {%1, %2};";
if(d == 4) ptx += ".m8n8.shared.b16 [%0], {%1, %2, %3, %4};";
return ptx;
}
string llvmBuilder = [{
auto operands = moduleTranslation.lookupValues(opInst.getOperands());
auto intId = getStMatrixIntrinsicId($layout, $sources.size(), $shape, $elttype);
createIntrinsicCall(builder, intId, operands, operands[0]->getType());
}];
let assemblyFormat = "$ptr `,` $sources attr-dict `:` type(operands)";
let hasVerifier = 1;
}

Expand Down
44 changes: 44 additions & 0 deletions mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,50 @@ static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
}
}

/// Select the LLVM intrinsic that implements `stmatrix` for the given
/// parameters.
///
/// Combinations that map to an intrinsic:
///   - shape m8n8 with element type b16: the row layout selects the plain
///     variant, any other layout selects the `trans` variant;
///   - shape m16n8 with element type b8 and the col layout (`trans` only).
/// In both cases \p num picks the x1 / x2 / x4 flavour and must be 1, 2 or 4.
/// Any other combination reaches llvm_unreachable.
static llvm::Intrinsic::ID
getStMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
                       NVVM::LdStMatrixShapeAttr shape,
                       NVVM::LdStMatrixEltType eltType) {
  const auto rows = shape.getM();
  const auto cols = shape.getN();

  // m8n8 / b16: both the plain and the transposed variants exist.
  if (rows == 8 && cols == 8 && eltType == NVVM::LdStMatrixEltType::B16) {
    const bool isRowLayout = layout == NVVM::MMALayout::row;
    if (num == 1)
      return isRowLayout
                 ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16
                 : llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16;
    if (num == 2)
      return isRowLayout
                 ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16
                 : llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16;
    if (num == 4)
      return isRowLayout
                 ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16
                 : llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16;
  }

  // m16n8 / b8: only the transposed (col layout) variant is handled.
  if (rows == 16 && cols == 8 && eltType == NVVM::LdStMatrixEltType::B8 &&
      layout == NVVM::MMALayout::col) {
    if (num == 1)
      return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8;
    if (num == 2)
      return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8;
    if (num == 4)
      return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8;
  }

  llvm_unreachable("unknown stmatrix kind");
}

/// Return the intrinsic ID associated with st.bulk for the given address type.
static llvm::Intrinsic::ID
getStBulkIntrinsicId(LLVM::LLVMPointerType addrType) {
Expand Down
24 changes: 0 additions & 24 deletions mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -580,30 +580,6 @@ func.func @elect_one_leader_sync() {

// -----

// CHECK-LABEL: @stmatrix(
// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !llvm.ptr<3>,
// CHECK-SAME: %[[arg1:[a-zA-Z0-9_]+]]: i32,
// CHECK-SAME: %[[arg2:[a-zA-Z0-9_]+]]: i32,
// CHECK-SAME: %[[arg3:[a-zA-Z0-9_]+]]: i32,
// CHECK-SAME: %[[arg4:[a-zA-Z0-9_]+]]: i32)
llvm.func @stmatrix(%arg0 : !llvm.ptr<3>, %m1 : i32, %m2 : i32, %m3 : i32, %m4 : i32) {
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.m8n8.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> ()
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.m8n8.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> ()
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> ()
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> ()
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> ()
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> ()
nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32
nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32
nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32, i32, i32
nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32
nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32, i32
nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32, i32, i32, i32
llvm.return
}

// -----

// CHECK-LABEL: @init_mbarrier_arrive_expect_tx
llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) {
//CHECK: llvm.inline_asm has_side_effects asm_dialect = att "prefetch.tensormap [$0];", "l"
Expand Down
23 changes: 23 additions & 0 deletions mlir/test/Target/LLVMIR/nvvmir.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,29 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
llvm.return
}

// Checks that nvvm.stmatrix is translated to a call of the matching
// llvm.nvvm.stmatrix.sync.aligned.* intrinsic. For each source count (1, 2, 4)
// three variants are exercised: m8n8/b16 with row layout, m8n8/b16 with col
// layout (trans), and m16n8/b8 with col layout (trans).
// CHECK-LABEL: @st_matrix
llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) {
// CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : !llvm.ptr<3>, i32
// CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : !llvm.ptr<3>, i32
// CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x1.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b8>} : !llvm.ptr<3>, i32
// CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : !llvm.ptr<3>, i32, i32
// CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : !llvm.ptr<3>, i32, i32
// CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x2.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b8>} : !llvm.ptr<3>, i32, i32
// CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x4.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : !llvm.ptr<3>, i32, i32, i32, i32
// CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x4.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : !llvm.ptr<3>, i32, i32, i32, i32
// CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x4.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b8>} : !llvm.ptr<3>, i32, i32, i32, i32
llvm.return
}

// This function has the "kernel" attribute attached and should appear in the
// NVVM annotations after conversion.
llvm.func @kernel_func() attributes {nvvm.kernel} {
Expand Down