From dacee0847b660af16f5c8f45bcb83708d6916b0c Mon Sep 17 00:00:00 2001 From: Gao Yanfeng Date: Fri, 11 Jul 2025 22:51:45 +0800 Subject: [PATCH 1/2] [MLIR] Fix the PTX generation bug for StMatrixOp --- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 6895e946b8a45..b27c03ec2c63f 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -2000,13 +2000,13 @@ def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">, let extraClassDefinition = [{ std::string $cppClass::getPtx() { int d = getSources().size(); - std::string ptx = "stmatrix.sync.aligned"; + std::string ptx = "stmatrix.sync.aligned.m8n8"; ptx += ".x" + std::to_string(d); if (getLayout() == NVVM::MMALayout::col) ptx += ".trans"; - if(d == 1) ptx += ".m8n8.shared.b16 [%0], {%1};"; - if(d == 2) ptx += ".m8n8.shared.b16 [%0], {%1, %2};"; - if(d == 4) ptx += ".m8n8.shared.b16 [%0], {%1, %2, %3, %4};"; + if(d == 1) ptx += ".shared.b16 [%0], {%1};"; + if(d == 2) ptx += ".shared.b16 [%0], {%1, %2};"; + if(d == 4) ptx += ".shared.b16 [%0], {%1, %2, %3, %4};"; return ptx; } }]; From 8ed67ea21ae58db16b60e80486f040088bacf212 Mon Sep 17 00:00:00 2001 From: Gao Yanfeng Date: Fri, 11 Jul 2025 23:38:08 +0800 Subject: [PATCH 2/2] Fix the testcase --- mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir index 8d720ce62a91b..b2073edfc6fd3 100644 --- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir +++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir @@ -587,12 +587,12 @@ func.func @elect_one_leader_sync() { // CHECK-SAME: %[[arg3:[a-zA-Z0-9_]+]]: i32, // CHECK-SAME: %[[arg4:[a-zA-Z0-9_]+]]: i32) llvm.func @stmatrix(%arg0 : !llvm.ptr<3>, %m1 : i32, %m2 : i32, %m3 : i32, %m4 : i32) { -// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.m8n8.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> () -// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.m8n8.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> () -// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> () -// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> () -// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> () -// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> () +// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.m8n8.x1.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> () +// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.m8n8.x2.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> () +// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.m8n8.x4.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> () +// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.m8n8.x1.trans.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> () +// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.m8n8.x2.trans.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> () +// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.m8n8.x4.trans.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> () nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout} : !llvm.ptr<3>, i32 nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout} : !llvm.ptr<3>, i32, i32 nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout} : !llvm.ptr<3>, i32, i32, i32, i32