From 5aed821f5dab515a59905342ec09b83bc6df336d Mon Sep 17 00:00:00 2001 From: Gao Yanfeng Date: Sat, 12 Jul 2025 23:17:41 +0800 Subject: [PATCH 1/4] [MLIR][NVVM][NVGPU] Support intrinsics about stmatrix --- llvm/include/llvm/IR/IntrinsicsNVVM.td | 65 +++++++++ llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 29 +++- llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 45 ++++++- llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py | 14 ++ llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py | 4 +- llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py | 4 +- llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py | 4 +- llvm/test/CodeGen/NVPTX/wmma.py | 125 ++++++++++++++++++ mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 39 +++--- .../Dialect/NVVM/NVVMToLLVMIRTranslation.cpp | 43 ++++++ .../Conversion/NVVMToLLVM/nvvm-to-llvm.mlir | 24 ---- mlir/test/Target/LLVMIR/nvvmir.mlir | 23 ++++ 12 files changed, 365 insertions(+), 54 deletions(-) create mode 100644 llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td index 0375f29ad8906..aad21fd4cba1c 100644 --- a/llvm/include/llvm/IR/IntrinsicsNVVM.td +++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td @@ -331,6 +331,11 @@ class WMMA_REGS { !eq(gf,"m8n16:x2") : !listsplat(llvm_i32_ty, 2), !eq(gf,"m8n16:x4") : !listsplat(llvm_i32_ty, 4), + // stmatrix b8 -> s32 @ m16n8 + !eq(gf,"m16n8:x1") : !listsplat(llvm_i32_ty, 1), + !eq(gf,"m16n8:x2") : !listsplat(llvm_i32_ty, 2), + !eq(gf,"m16n8:x4") : !listsplat(llvm_i32_ty, 4), + ); } @@ -403,6 +408,17 @@ class LDMATRIX_NAME { !subst("llvm.", "int_", intr)); } +class STMATRIX_NAME { + string intr = "llvm.nvvm.stmatrix.sync.aligned" + # "." # Frag.geom + # "." # Frag.frag + # !if(Trans, ".trans", "") + # "." # Frag.ptx_elt_type + ; + string record = !subst(".", "_", + !subst("llvm.", "int_", intr)); +} + // Generates list of 4-tuples of WMMA_REGS representing a valid MMA op. // Geom: list of supported geometries. // TypeN: PTX type of the corresponding fragment's element. @@ -443,6 +459,16 @@ class LDMATRIX_OPS Geom, list Frags, list Types> { list ops = !foreach(x, ret, x.gft); } +class STMATRIX_OPS Geom, list Frags, list Types> { + list ret = + !foldl([], Geom, t1, geom, !listconcat(t1, + !foldl([], Frags, t2, frag, !listconcat(t2, + !foldl([], Types, t3, type, !listconcat(t3, + [WMMA_REGS])))))); + // Debugging aid for readable representation of the list above. + list ops = !foreach(x, ret, x.gft); +} + // Creates list of valid combinations of fragments. This is the main list that // drives generation of corresponding intrinsics and instructions. class NVVM_MMA_OPS { @@ -537,9 +563,18 @@ class NVVM_MMA_OPS { list ldmatrix_geom_m8n16_ops = LDMATRIX_OPS< ["m8n16"], ["x1", "x2", "x4"], ["b8x16.b6x16_p32", "b8x16.b4x16_p64"]>.ret; + list stmatrix_b16_ops = STMATRIX_OPS< + ["m8n8"], ["x1", "x2", "x4"], ["b16"]>.ret; + + list stmatrix_b8_ops = STMATRIX_OPS< + ["m16n8"], ["x1", "x2", "x4"], ["b8"]>.ret; + list all_ldmatrix_ops = !listconcat(ldmatrix_b16_ops, ldmatrix_geom_m16n16_ops, ldmatrix_geom_m8n16_ops); + + list all_stmatrix_ops = !listconcat(stmatrix_b16_ops, + stmatrix_b8_ops); } def NVVM_MMA_OPS : NVVM_MMA_OPS; @@ -680,6 +715,19 @@ class NVVM_LDMATRIX_SUPPORTED { ); } +// Returns true if the fragment is valid for stmatrix ops is supported; +// false otherwise. 
+class NVVM_STMATRIX_SUPPORTED { + string g = frag.geom; + string t = frag.ptx_elt_type; + + bit ret = !cond( + !and(!eq(g, "m8n8"), !eq(t, "b16")): true, + !and(!eq(g, "m16n8"), !eq(t, "b8"), !eq(trans, 1)): true, + true: false + ); +} + class SHFL_INFO { string Suffix = !if(sync, "sync_", "") # mode # "_" @@ -1969,6 +2017,23 @@ foreach transposed = [0, 1] in { } } +// STMATRIX +class NVVM_STMATRIX + : Intrinsic<[], + !listconcat([llvm_anyptr_ty], Frag.regs), + [IntrWriteMem, IntrArgMemOnly, IntrNoCallback, + WriteOnly>, NoCapture>], + STMATRIX_NAME.intr>; + +foreach transposed = [0, 1] in { + foreach frag = NVVM_MMA_OPS.all_stmatrix_ops in { + if NVVM_STMATRIX_SUPPORTED.ret then { + def STMATRIX_NAME.record + : NVVM_STMATRIX; + } + } +} + // MAPA let IntrProperties = [IntrNoMem, IntrSpeculatable, NoCapture>] in { def int_nvvm_mapa diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index 3d010e04824c5..d94be492b0c02 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -3952,7 +3952,10 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic( case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col: case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col_stride: case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row: - case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride: { + case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride: + case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16: + case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16: + case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8: { Info.opc = ISD::INTRINSIC_VOID; Info.memVT = MVT::v2i32; Info.ptrVal = I.getArgOperand(0); @@ -3975,6 +3978,30 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic( return true; } + case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16: + case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16: + case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8: { + Info.opc = ISD::INTRINSIC_VOID; + Info.memVT = MVT::i32; + Info.ptrVal = I.getArgOperand(0); + Info.offset = 0; + Info.flags = MachineMemOperand::MOStore; + Info.align = Align(4); + return true; + } + + case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16: + case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16: + case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8: { + Info.opc = ISD::INTRINSIC_VOID; + Info.memVT = MVT::v4i32; + Info.ptrVal = I.getArgOperand(0); + Info.offset = 0; + Info.flags = MachineMemOperand::MOStore; + Info.align = Align(16); + return true; + } + case Intrinsic::nvvm_atomic_add_gen_f_cta: case Intrinsic::nvvm_atomic_add_gen_f_sys: case Intrinsic::nvvm_atomic_add_gen_i_cta: diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index 93827be5c2811..1e24bf8ab99e1 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -4597,7 +4597,14 @@ class WMMA_REGINFO !and(!eq(op, "ldmatrix"), !eq(ptx_elt_type, "b8x16.b4x16_p64"), - !eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>]); + !eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>], + + !and(!eq(op, "stmatrix"),!eq(ptx_elt_type, "b16"), + !eq(geom, "m8n8")) : [hasSM<90>, hasPTX<78>], + + !and(!eq(op, "stmatrix"), + !eq(ptx_elt_type, "b8"), + !eq(geom, "m16n8")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>]); // template DAGs for instruction inputs/output. 
dag Outs = !dag(outs, ptx_regs, reg_names); @@ -4878,6 +4885,40 @@ defset list LDMATRIXs = { } // transposed } // defset +// +// stmatrix.sync.aligned.m8n8[|.trans][|.shared].b16 +// +class STMATRIX + : WMMA_INSTR.record, [!con((ins ADDR:$dst), Frag.Ins)]>, + Requires { + // Build PatFrag that only matches particular address space. + dag PFOperands = !con((ops node:$dst), !dag(ops, !listsplat(node, !size(Frag.regs)), Frag.reg_names)); + PatFrag IntrFrag = PatFrag; + // Build AS-constrained pattern. + let IntrinsicPattern = BuildPatternPF.ret; + let OutOperandList = (outs); + let InOperandList = !con(Args, (ins MmaCode:$ptx)); + let AsmString = "stmatrix.sync.aligned." + # Frag.geom + # "." # Frag.frag + # !if(Transposed, ".trans", "") + # Space + # "." # Frag.ptx_elt_type + # " [$dst], " # Frag.regstring # ";"; +} + +// Create all stmatrix variants +defset list STMATRIXs = { + foreach transposed = [false, true] in {foreach space = [".shared", ""] in { + foreach frag = NVVM_MMA_OPS.all_stmatrix_ops in + if NVVM_STMATRIX_SUPPORTED.ret then + def : STMATRIX, transposed, space>; + } // space + } // transposed +} // defset + // Constructing non-flat DAGs is still a pain. I can't !subst a dag node with a // dag, so the ptx.version must be appended *after* foreach replaces 'ins' with // the instruction record. @@ -4888,7 +4929,7 @@ class MMA_PAT Requires; // Build intrinsic->instruction patterns for all MMA instructions. -foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs) in +foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs, STMATRIXs) in def : MMA_PAT; multiclass MAPA { diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py b/llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py new file mode 100644 index 0000000000000..8f502065345c1 --- /dev/null +++ b/llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py @@ -0,0 +1,14 @@ +# Check all variants of instructions supported by PTX78 on SM90 +# RUN: %python %s --ptx=78 --gpu-arch=90 --aa > %t-ptx78-sm_90.ll +# RUN: FileCheck %t-ptx78-sm_90.ll < %t-ptx78-sm_90.ll \ +# RUN: --check-prefixes=PTX78STMATRIX-DAG +# RUN: llc < %t-ptx78-sm_90.ll -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78 \ +# RUN: | FileCheck %t-ptx78-sm_90.ll +# RUN: %if ptxas-12.7 %{ \ +# RUN: llc < %t-ptx78-sm_90.ll -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78 \ +# RUN: | %ptxas-verify -arch=sm_90 \ +# RUN: %} + +import wmma + +wmma.main() diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py index 6ad0a2a5865c4..5c14a54601ed9 100644 --- a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py +++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py @@ -1,9 +1,7 @@ # Check all variants of instructions supported by PTX86 on SM100a # RUN: %python %s --ptx=86 --gpu-arch=100 --aa > %t-ptx86-sm_100a.ll # RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \ -# RUN: --check-prefixes=PTX86LDMATRIX-DAG -# RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \ -# RUN: --check-prefixes=PTX86LDMATRIX-DAG +# RUN: --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG # RUN: llc < %t-ptx86-sm_100a.ll -mtriple=nvptx64 -mcpu=sm_100a -mattr=+ptx86 \ # RUN: | FileCheck %t-ptx86-sm_100a.ll # RUN: %if ptxas-12.7 %{ \ diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py index 7d9953484da7d..a77f9adddff9c 100644 --- a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py +++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py @@ -1,9 +1,7 @@ # Check all variants of instructions supported by PTX86 on SM101a # RUN: %python %s --ptx=86 
--gpu-arch=101 --aa > %t-ptx86-sm_101a.ll # RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \ -# RUN: --check-prefixes=PTX86LDMATRIX-DAG -# RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \ -# RUN: --check-prefixes=PTX86LDMATRIX-DAG +# RUN: --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG # RUN: llc < %t-ptx86-sm_101a.ll -mtriple=nvptx64 -mcpu=sm_101a -mattr=+ptx86 \ # RUN: | FileCheck %t-ptx86-sm_101a.ll # RUN: %if ptxas-12.7 %{ \ diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py index 7bddf0b6fbb78..8126e64d6cc85 100644 --- a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py +++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py @@ -1,9 +1,7 @@ # Check all variants of instructions supported by PTX86 on SM120a # RUN: %python %s --ptx=86 --gpu-arch=120 --aa > %t-ptx86-sm_120a.ll # RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \ -# RUN: --check-prefixes=PTX86LDMATRIX-DAG -# RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \ -# RUN: --check-prefixes=PTX86LDMATRIX-DAG +# RUN: --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG # RUN: llc < %t-ptx86-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx86 \ # RUN: | FileCheck %t-ptx86-sm_120a.ll # RUN: %if ptxas-12.7 %{ \ diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py index 2ee489670e9e4..3888e9b6b1b8d 100644 --- a/llvm/test/CodeGen/NVPTX/wmma.py +++ b/llvm/test/CodeGen/NVPTX/wmma.py @@ -10,6 +10,7 @@ from itertools import product from string import Template + class MMAType: def __init__(self, ptx_type): self.ptx_type = ptx_type @@ -176,6 +177,13 @@ def __init__(self, geom, frag, ptx_elt_type): "m8n16:x1:b8x16.b4x16_p64": 1, "m8n16:x2:b8x16.b4x16_p64": 2, "m8n16:x4:b8x16.b4x16_p64": 4, + # stmatrix + "m8n8:x1:b16": 1, + "m8n8:x2:b16": 2, + "m8n8:x4:b16": 4, + "m16n8:x1:b8": 1, + "m16n8:x2:b8": 2, + "m16n8:x4:b8": 4, }.get( "%s:%s:%s" % (geom, frag, ptx_elt_type), { @@ -241,6 +249,13 @@ def make_ldmatrix_ops(geoms, frags, types): ] +def make_stmatrix_ops(geoms, frags, types): + return [ + MMAFrag(geom, frag, ptx_type) + for (geom, frag, ptx_type) in product(geoms, frags, types) + ] + + def get_wmma_ops(): return ( make_mma_ops(["m16n16k8"], ["tf32"], [], ["f32"], []) @@ -315,6 +330,12 @@ def get_ldmatrix_ops(): ) +def get_stmatrix_ops(): + return make_stmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"]) + make_stmatrix_ops( + ["m16n8"], ["x1", "x2", "x4"], ["b8"] + ) + + def is_wmma_geom_supported(geom): # geometries for FP and ints. if geom in ["m8n32k16", "m32n8k16"]: @@ -360,6 +381,14 @@ def is_ldmatrix_geom_supported(geom): assert False # Unexpected geometry. +def is_stmatrix_geom_supported(geom): + if geom in ["m8n8"]: + return ptx_version >= 78 and gpu_arch >= 90 + elif geom in ["m16n8"]: + return ptx_version >= 86 and gpu_arch >= 100 and aa + assert False # Unexpected geometry. + + def is_ldmatrix_trans_supported(geom, trans): if geom in ["m8n8"]: return True @@ -369,6 +398,15 @@ def is_ldmatrix_trans_supported(geom, trans): return trans == "" assert False # Unexpected geometry. + +def is_stmatrix_trans_supported(geom, trans): + if geom in ["m8n8"]: + return True + elif geom in ["m16n8"]: + return trans == ".trans" + assert False # Unexpected geometry. 
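
Taken together, the two predicates above restrict the generator to m8n8/b16 in both layouts (gated on PTX 7.8 and sm_90) and to m16n8/b8 only in the .trans form (gated on PTX 8.6 and sm_100a-class targets, i.e. --aa). As a rough sketch, one of the test cases the generator is expected to emit looks like the following LLVM IR; the function and value names here are illustrative, not taken from the generator's actual output:

  declare void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.trans.b16.p3(ptr addrspace(3), i32, i32)

  define void @store_two_tiles(ptr addrspace(3) %dst, i32 %x0, i32 %x1) {
    ; Each i32 register carries two packed b16 elements; expected PTX:
    ;   stmatrix.sync.aligned.m8n8.x2.trans.shared.b16 [%dst], {%x0, %x1};
    call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.trans.b16.p3(ptr addrspace(3) %dst, i32 %x0, i32 %x1)
    ret void
  }

The .p3 suffix is the pspace mangling the templates below attach for the .shared variant; the generic-address form drops .shared from the PTX and uses the generic pointer mangling instead.
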
+ + def is_type_supported(ptx_type): if ptx_type in ["s8", "u8", "s32"]: return ptx_version >= 63 and gpu_arch >= 72 @@ -463,6 +501,16 @@ def is_ldmatrix_variant_supported(frag, trans): return frag.frag in ["x1", "x2", "x4"] +def is_stmatrix_variant_supported(frag, trans): + if not ( + is_type_supported(frag.mma_type.ptx_type) + and is_stmatrix_geom_supported(frag.geom) + and is_stmatrix_trans_supported(frag.geom, trans) + ): + return False + return frag.frag in ["x1", "x2", "x4"] + + def make_wmma_slice_ty(frag): return [frag.mma_type.llvm_type] * frag.nregs @@ -716,6 +764,61 @@ def gen_ldmatrix_tests(): return generated_items +def gen_stmatrix_tests(): + stmatrix_template = """ +declare void @${intrinsic}(i8 ${as}* %dst, ${args}); + +; CHECK-LABEL: .func {{.*}}test_${function}( +define void @test_${function}(i8 ${as}* %dst, ${args}) { +; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}] +; CHECK: {${check_args}} + call void @${intrinsic}(i8${as}* %dst, ${args}); + ret void +} + +; CHECK-LABEL: .func{{.*}}test_${function}_o( +define void @test_${function}_o(i8 ${as}* %dst, ${args}) { +; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}+128], +; CHECK: {${check_args}} + %dst1 = getelementptr i8, i8 ${as}* %dst, i32 128; + call void @${intrinsic}(i8 ${as}* %dst1, ${args}); + ret void +} +""" + intrinsic_template = ( + "llvm.nvvm.stmatrix.sync.aligned.${geom}.${frag}${trans}.${itype}.${pspace}" + ) + instruction_template = ("stmatrix.sync.aligned.${geom}.${frag}${trans}${space}.${itype}" + ) + generated_items = [] + + for frag, space, trans in product(get_stmatrix_ops(), + ["", ".shared"], + ["", ".trans"], + ): + if not is_stmatrix_variant_supported(frag, trans): + continue + + params = { + "frag": frag.frag, + "space": space,"trans": trans, + "itype": frag.mma_type.ptx_type, + "pspace": get_pspace(space), + "as": "addrspace(%d)" % get_aspace(space), + "geom": frag.geom, + } + + test_params = params + test_params["intrinsic"] = Template(intrinsic_template).substitute(params) + test_params["function"] = test_params["intrinsic"].replace(".", "_") + test_params["instruction"] = Template(instruction_template).substitute(params) + test_params["args"] = make_wmma_slice_args(frag) + test_params["check_args"] = check_pattern(frag) + + print(Template(stmatrix_template).substitute(test_params)) + generated_items.append((test_params["intrinsic"], test_params["instruction"])) + + return generated_items def mma_signature(op): if op.a.mma_type.ptx_type == "f16": @@ -893,6 +996,7 @@ def gen_check_unsupported_ops(items): ; NOALTFLOAT-NOT: .{{bf16|tf32}} ; NODOUBLE-NOT: .f64 ; NOLDMATRIX-NOT: ldmatrix.sync.aligned +; NOSTMATRIX-NOT: stmatrix.sync.aligned ; M16N16-DAG: m16n16k16.load.{{[ab].*}}.f16.p ; M16N16-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p @@ -994,6 +1098,26 @@ def gen_check_unsupported_ops(items): ; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32 ; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.trans.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.trans.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.trans.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.shared.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.shared.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.shared.b16 +; PTX78STMATRIX-DAG: 
stmatrix.sync.aligned.m8n8.x1.trans.shared.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.trans.shared.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.trans.shared.b16 + +; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x1.trans.b8 +; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x2.trans.b8 +; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x4.trans.b8 +; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x1.trans.shared.b8 +; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x2.trans.shared.b8 +; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x4.trans.shared.b8 + ; PTX71MMA-DAG: mma.m8n8k4.row.col.f64 ; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32 ; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32 @@ -1039,6 +1163,7 @@ def gen_tests(): items = gen_wmma_load_tests() items += gen_wmma_store_tests() items += gen_ldmatrix_tests() + items += gen_stmatrix_tests() items += gen_wmma_mma_tests() items += gen_mma_tests() gen_check_unsupported_ops(items) diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 45a8904375e2b..8de5932aaf2e3 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -1990,10 +1990,22 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">, let hasVerifier = 1; } -def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">, - Arguments<(ins LLVM_PointerShared:$ptr, - Variadic:$sources, - MMALayoutAttr:$layout)> { +def LdStMatrixShapeM8N8 : I32EnumAttrCase<"M8N8", 0, "m8n8">; +def LdStMatrixShapeM8N16 : I32EnumAttrCase<"M8N16", 1, "m8n16">; +def LdStMatrixShapeM16N8 : I32EnumAttrCase<"M16N8", 2, "m16n8">; +def LdStMatrixShapeM16N16 : I32EnumAttrCase<"M16N16", 3, "m16n16">; + +def LdStMatrixShape : I32EnumAttr<"LdStMatrixShape", "Matrix shape for ldmatrix and stmatrix", + [LdStMatrixShapeM8N8, LdStMatrixShapeM16N8, LdStMatrixShapeM16N16]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::NVVM"; +} +def LdStMatrixShapeAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +def NVVM_StMatrixOp: NVVM_Op<"stmatrix">, + Arguments<(ins LLVM_AnyPointer: $ptr, Variadic:$sources, MMALayoutAttr:$layout, LdStMatrixShapeAttr:$shape)> { let summary = "cooperative matrix store"; let description = [{ Collectively store one or more matrices across all threads in a warp to the @@ -2001,21 +2013,12 @@ def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">, [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix) }]; - - let assemblyFormat = "$ptr `,` $sources attr-dict `:` type(operands)"; - let extraClassDefinition = [{ - std::string $cppClass::getPtx() { - int d = getSources().size(); - std::string ptx = "stmatrix.sync.aligned"; - ptx += ".x" + std::to_string(d); - if (getLayout() == NVVM::MMALayout::col) - ptx += ".trans"; - if(d == 1) ptx += ".m8n8.shared.b16 [%0], {%1};"; - if(d == 2) ptx += ".m8n8.shared.b16 [%0], {%1, %2};"; - if(d == 4) ptx += ".m8n8.shared.b16 [%0], {%1, %2, %3, %4};"; - return ptx; - } + string llvmBuilder = [{ + auto operands = moduleTranslation.lookupValues(opInst.getOperands()); + auto intId = getStMatrixIntrinsicId($layout, $sources.size(), $shape); + createIntrinsicCall(builder, intId, operands, operands[0]->getType()); }]; + let assemblyFormat = "$ptr `,` $sources attr-dict `:` type(operands)"; let hasVerifier = 1; } diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp index 
eecca64c4bf81..d03242f402ec5 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp @@ -163,6 +163,49 @@ static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout, } } +/// Return the intrinsic ID associated with stmatrix for the given paramters. +static llvm::Intrinsic::ID getStMatrixIntrinsicId(NVVM::MMALayout layout, + int32_t num, + NVVM::LdStMatrixShape shape) { + if (shape == NVVM::LdStMatrixShape::M8N8) { + if (layout == NVVM::MMALayout::row) { + switch (num) { + case 1: + return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16; + case 2: + return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16; + case 4: + return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16; + default: + llvm_unreachable("unsupported number of matrix"); + } + } else { + switch (num) { + case 1: + return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16; + case 2: + return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16; + case 4: + return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16; + default: + llvm_unreachable("unsupported number of matrix"); + } + } + } else { + // for 16x8 matrices, .trans is mandatory + switch (num) { + case 1: + return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8; + case 2: + return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8; + case 4: + return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8; + default: + llvm_unreachable("unsupported number of matrix"); + } + } +} + /// Return the intrinsic ID associated with st.bulk for the given address type. static llvm::Intrinsic::ID getStBulkIntrinsicId(LLVM::LLVMPointerType addrType) { diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir index 8d720ce62a91b..580b09d70c480 100644 --- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir +++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir @@ -580,30 +580,6 @@ func.func @elect_one_leader_sync() { // ----- -// CHECK-LABEL: @stmatrix( -// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !llvm.ptr<3>, -// CHECK-SAME: %[[arg1:[a-zA-Z0-9_]+]]: i32, -// CHECK-SAME: %[[arg2:[a-zA-Z0-9_]+]]: i32, -// CHECK-SAME: %[[arg3:[a-zA-Z0-9_]+]]: i32, -// CHECK-SAME: %[[arg4:[a-zA-Z0-9_]+]]: i32) -llvm.func @stmatrix(%arg0 : !llvm.ptr<3>, %m1 : i32, %m2 : i32, %m3 : i32, %m4 : i32) { -// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.m8n8.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> () -// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.m8n8.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> () -// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> () -// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> () -// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> () -// CHECK: llvm.inline_asm has_side_effects asm_dialect = att 
"stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> () - nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout} : !llvm.ptr<3>, i32 - nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout} : !llvm.ptr<3>, i32, i32 - nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout} : !llvm.ptr<3>, i32, i32, i32, i32 - nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout} : !llvm.ptr<3>, i32 - nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout} : !llvm.ptr<3>, i32, i32 - nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout} : !llvm.ptr<3>, i32, i32, i32, i32 - llvm.return -} - -// ----- - // CHECK-LABEL: @init_mbarrier_arrive_expect_tx llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) { //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "prefetch.tensormap [$0];", "l" diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir index f86a04186f512..3be35faf091e2 100644 --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -573,6 +573,29 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { llvm.return } +// CHECK-LABEL: @st_matrix +llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape} : !llvm.ptr<3>, i32 + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape} : !llvm.ptr<3>, i32 + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x1.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape} : !llvm.ptr<3>, i32 + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape} : !llvm.ptr<3>, i32, i32 + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape} : !llvm.ptr<3>, i32, i32 + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x2.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape} : !llvm.ptr<3>, i32, i32 + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x4.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape} : !llvm.ptr<3>, i32, i32, i32, i32 + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x4.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape} : !llvm.ptr<3>, i32, i32, i32, i32 + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x4.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = 
#nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape} : !llvm.ptr<3>, i32, i32, i32, i32 + llvm.return +} + // This function has the "kernel" attribute attached and should appear in the // NVVM annotations after conversion. llvm.func @kernel_func() attributes {nvvm.kernel} { From 653ae854e5d88f62b7e2e2353f8bb385251294eb Mon Sep 17 00:00:00 2001 From: Gao Yanfeng Date: Mon, 14 Jul 2025 10:40:47 +0800 Subject: [PATCH 2/4] Remove changes on NVPTX --- llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 29 +---- llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 45 +------ llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py | 14 --- llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py | 4 +- llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py | 4 +- llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py | 4 +- llvm/test/CodeGen/NVPTX/wmma.py | 125 ------------------- 7 files changed, 12 insertions(+), 213 deletions(-) delete mode 100644 llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index d94be492b0c02..3d010e04824c5 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -3952,10 +3952,7 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic( case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col: case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col_stride: case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row: - case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride: - case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16: - case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16: - case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8: { + case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride: { Info.opc = ISD::INTRINSIC_VOID; Info.memVT = MVT::v2i32; Info.ptrVal = I.getArgOperand(0); @@ -3978,30 +3975,6 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic( return true; } - case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16: - case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16: - case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8: { - Info.opc = ISD::INTRINSIC_VOID; - Info.memVT = MVT::i32; - Info.ptrVal = I.getArgOperand(0); - Info.offset = 0; - Info.flags = MachineMemOperand::MOStore; - Info.align = Align(4); - return true; - } - - case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16: - case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16: - case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8: { - Info.opc = ISD::INTRINSIC_VOID; - Info.memVT = MVT::v4i32; - Info.ptrVal = I.getArgOperand(0); - Info.offset = 0; - Info.flags = MachineMemOperand::MOStore; - Info.align = Align(16); - return true; - } - case Intrinsic::nvvm_atomic_add_gen_f_cta: case Intrinsic::nvvm_atomic_add_gen_f_sys: case Intrinsic::nvvm_atomic_add_gen_i_cta: diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index 1e24bf8ab99e1..93827be5c2811 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -4597,14 +4597,7 @@ class WMMA_REGINFO !and(!eq(op, "ldmatrix"), !eq(ptx_elt_type, "b8x16.b4x16_p64"), - !eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>], - - !and(!eq(op, "stmatrix"),!eq(ptx_elt_type, "b16"), - !eq(geom, "m8n8")) : [hasSM<90>, hasPTX<78>], - - !and(!eq(op, "stmatrix"), - !eq(ptx_elt_type, "b8"), - !eq(geom, "m16n8")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>]); + !eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>]); // template DAGs for 
instruction inputs/output. dag Outs = !dag(outs, ptx_regs, reg_names); @@ -4885,40 +4878,6 @@ defset list LDMATRIXs = { } // transposed } // defset -// -// stmatrix.sync.aligned.m8n8[|.trans][|.shared].b16 -// -class STMATRIX - : WMMA_INSTR.record, [!con((ins ADDR:$dst), Frag.Ins)]>, - Requires { - // Build PatFrag that only matches particular address space. - dag PFOperands = !con((ops node:$dst), !dag(ops, !listsplat(node, !size(Frag.regs)), Frag.reg_names)); - PatFrag IntrFrag = PatFrag; - // Build AS-constrained pattern. - let IntrinsicPattern = BuildPatternPF.ret; - let OutOperandList = (outs); - let InOperandList = !con(Args, (ins MmaCode:$ptx)); - let AsmString = "stmatrix.sync.aligned." - # Frag.geom - # "." # Frag.frag - # !if(Transposed, ".trans", "") - # Space - # "." # Frag.ptx_elt_type - # " [$dst], " # Frag.regstring # ";"; -} - -// Create all stmatrix variants -defset list STMATRIXs = { - foreach transposed = [false, true] in {foreach space = [".shared", ""] in { - foreach frag = NVVM_MMA_OPS.all_stmatrix_ops in - if NVVM_STMATRIX_SUPPORTED.ret then - def : STMATRIX, transposed, space>; - } // space - } // transposed -} // defset - // Constructing non-flat DAGs is still a pain. I can't !subst a dag node with a // dag, so the ptx.version must be appended *after* foreach replaces 'ins' with // the instruction record. @@ -4929,7 +4888,7 @@ class MMA_PAT Requires; // Build intrinsic->instruction patterns for all MMA instructions. -foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs, STMATRIXs) in +foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs) in def : MMA_PAT; multiclass MAPA { diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py b/llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py deleted file mode 100644 index 8f502065345c1..0000000000000 --- a/llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py +++ /dev/null @@ -1,14 +0,0 @@ -# Check all variants of instructions supported by PTX78 on SM90 -# RUN: %python %s --ptx=78 --gpu-arch=90 --aa > %t-ptx78-sm_90.ll -# RUN: FileCheck %t-ptx78-sm_90.ll < %t-ptx78-sm_90.ll \ -# RUN: --check-prefixes=PTX78STMATRIX-DAG -# RUN: llc < %t-ptx78-sm_90.ll -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78 \ -# RUN: | FileCheck %t-ptx78-sm_90.ll -# RUN: %if ptxas-12.7 %{ \ -# RUN: llc < %t-ptx78-sm_90.ll -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78 \ -# RUN: | %ptxas-verify -arch=sm_90 \ -# RUN: %} - -import wmma - -wmma.main() diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py index 5c14a54601ed9..6ad0a2a5865c4 100644 --- a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py +++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py @@ -1,7 +1,9 @@ # Check all variants of instructions supported by PTX86 on SM100a # RUN: %python %s --ptx=86 --gpu-arch=100 --aa > %t-ptx86-sm_100a.ll # RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \ -# RUN: --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG +# RUN: --check-prefixes=PTX86LDMATRIX-DAG +# RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \ +# RUN: --check-prefixes=PTX86LDMATRIX-DAG # RUN: llc < %t-ptx86-sm_100a.ll -mtriple=nvptx64 -mcpu=sm_100a -mattr=+ptx86 \ # RUN: | FileCheck %t-ptx86-sm_100a.ll # RUN: %if ptxas-12.7 %{ \ diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py index a77f9adddff9c..7d9953484da7d 100644 --- a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py +++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py @@ -1,7 +1,9 @@ # Check all variants of instructions supported by PTX86 on 
SM101a # RUN: %python %s --ptx=86 --gpu-arch=101 --aa > %t-ptx86-sm_101a.ll # RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \ -# RUN: --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG +# RUN: --check-prefixes=PTX86LDMATRIX-DAG +# RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \ +# RUN: --check-prefixes=PTX86LDMATRIX-DAG # RUN: llc < %t-ptx86-sm_101a.ll -mtriple=nvptx64 -mcpu=sm_101a -mattr=+ptx86 \ # RUN: | FileCheck %t-ptx86-sm_101a.ll # RUN: %if ptxas-12.7 %{ \ diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py index 8126e64d6cc85..7bddf0b6fbb78 100644 --- a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py +++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py @@ -1,7 +1,9 @@ # Check all variants of instructions supported by PTX86 on SM120a # RUN: %python %s --ptx=86 --gpu-arch=120 --aa > %t-ptx86-sm_120a.ll # RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \ -# RUN: --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG +# RUN: --check-prefixes=PTX86LDMATRIX-DAG +# RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \ +# RUN: --check-prefixes=PTX86LDMATRIX-DAG # RUN: llc < %t-ptx86-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx86 \ # RUN: | FileCheck %t-ptx86-sm_120a.ll # RUN: %if ptxas-12.7 %{ \ diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py index 3888e9b6b1b8d..2ee489670e9e4 100644 --- a/llvm/test/CodeGen/NVPTX/wmma.py +++ b/llvm/test/CodeGen/NVPTX/wmma.py @@ -10,7 +10,6 @@ from itertools import product from string import Template - class MMAType: def __init__(self, ptx_type): self.ptx_type = ptx_type @@ -177,13 +176,6 @@ def __init__(self, geom, frag, ptx_elt_type): "m8n16:x1:b8x16.b4x16_p64": 1, "m8n16:x2:b8x16.b4x16_p64": 2, "m8n16:x4:b8x16.b4x16_p64": 4, - # stmatrix - "m8n8:x1:b16": 1, - "m8n8:x2:b16": 2, - "m8n8:x4:b16": 4, - "m16n8:x1:b8": 1, - "m16n8:x2:b8": 2, - "m16n8:x4:b8": 4, }.get( "%s:%s:%s" % (geom, frag, ptx_elt_type), { @@ -249,13 +241,6 @@ def make_ldmatrix_ops(geoms, frags, types): ] -def make_stmatrix_ops(geoms, frags, types): - return [ - MMAFrag(geom, frag, ptx_type) - for (geom, frag, ptx_type) in product(geoms, frags, types) - ] - - def get_wmma_ops(): return ( make_mma_ops(["m16n16k8"], ["tf32"], [], ["f32"], []) @@ -330,12 +315,6 @@ def get_ldmatrix_ops(): ) -def get_stmatrix_ops(): - return make_stmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"]) + make_stmatrix_ops( - ["m16n8"], ["x1", "x2", "x4"], ["b8"] - ) - - def is_wmma_geom_supported(geom): # geometries for FP and ints. if geom in ["m8n32k16", "m32n8k16"]: @@ -381,14 +360,6 @@ def is_ldmatrix_geom_supported(geom): assert False # Unexpected geometry. -def is_stmatrix_geom_supported(geom): - if geom in ["m8n8"]: - return ptx_version >= 78 and gpu_arch >= 90 - elif geom in ["m16n8"]: - return ptx_version >= 86 and gpu_arch >= 100 and aa - assert False # Unexpected geometry. - - def is_ldmatrix_trans_supported(geom, trans): if geom in ["m8n8"]: return True @@ -398,15 +369,6 @@ def is_ldmatrix_trans_supported(geom, trans): return trans == "" assert False # Unexpected geometry. - -def is_stmatrix_trans_supported(geom, trans): - if geom in ["m8n8"]: - return True - elif geom in ["m16n8"]: - return trans == ".trans" - assert False # Unexpected geometry. 
- - def is_type_supported(ptx_type): if ptx_type in ["s8", "u8", "s32"]: return ptx_version >= 63 and gpu_arch >= 72 @@ -501,16 +463,6 @@ def is_ldmatrix_variant_supported(frag, trans): return frag.frag in ["x1", "x2", "x4"] -def is_stmatrix_variant_supported(frag, trans): - if not ( - is_type_supported(frag.mma_type.ptx_type) - and is_stmatrix_geom_supported(frag.geom) - and is_stmatrix_trans_supported(frag.geom, trans) - ): - return False - return frag.frag in ["x1", "x2", "x4"] - - def make_wmma_slice_ty(frag): return [frag.mma_type.llvm_type] * frag.nregs @@ -764,61 +716,6 @@ def gen_ldmatrix_tests(): return generated_items -def gen_stmatrix_tests(): - stmatrix_template = """ -declare void @${intrinsic}(i8 ${as}* %dst, ${args}); - -; CHECK-LABEL: .func {{.*}}test_${function}( -define void @test_${function}(i8 ${as}* %dst, ${args}) { -; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}] -; CHECK: {${check_args}} - call void @${intrinsic}(i8${as}* %dst, ${args}); - ret void -} - -; CHECK-LABEL: .func{{.*}}test_${function}_o( -define void @test_${function}_o(i8 ${as}* %dst, ${args}) { -; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}+128], -; CHECK: {${check_args}} - %dst1 = getelementptr i8, i8 ${as}* %dst, i32 128; - call void @${intrinsic}(i8 ${as}* %dst1, ${args}); - ret void -} -""" - intrinsic_template = ( - "llvm.nvvm.stmatrix.sync.aligned.${geom}.${frag}${trans}.${itype}.${pspace}" - ) - instruction_template = ("stmatrix.sync.aligned.${geom}.${frag}${trans}${space}.${itype}" - ) - generated_items = [] - - for frag, space, trans in product(get_stmatrix_ops(), - ["", ".shared"], - ["", ".trans"], - ): - if not is_stmatrix_variant_supported(frag, trans): - continue - - params = { - "frag": frag.frag, - "space": space,"trans": trans, - "itype": frag.mma_type.ptx_type, - "pspace": get_pspace(space), - "as": "addrspace(%d)" % get_aspace(space), - "geom": frag.geom, - } - - test_params = params - test_params["intrinsic"] = Template(intrinsic_template).substitute(params) - test_params["function"] = test_params["intrinsic"].replace(".", "_") - test_params["instruction"] = Template(instruction_template).substitute(params) - test_params["args"] = make_wmma_slice_args(frag) - test_params["check_args"] = check_pattern(frag) - - print(Template(stmatrix_template).substitute(test_params)) - generated_items.append((test_params["intrinsic"], test_params["instruction"])) - - return generated_items def mma_signature(op): if op.a.mma_type.ptx_type == "f16": @@ -996,7 +893,6 @@ def gen_check_unsupported_ops(items): ; NOALTFLOAT-NOT: .{{bf16|tf32}} ; NODOUBLE-NOT: .f64 ; NOLDMATRIX-NOT: ldmatrix.sync.aligned -; NOSTMATRIX-NOT: stmatrix.sync.aligned ; M16N16-DAG: m16n16k16.load.{{[ab].*}}.f16.p ; M16N16-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p @@ -1098,26 +994,6 @@ def gen_check_unsupported_ops(items): ; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32 ; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64 -; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.b16 -; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.b16 -; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.b16 -; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.trans.b16 -; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.trans.b16 -; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.trans.b16 -; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.shared.b16 -; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.shared.b16 -; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.shared.b16 -; PTX78STMATRIX-DAG: 
stmatrix.sync.aligned.m8n8.x1.trans.shared.b16 -; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.trans.shared.b16 -; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.trans.shared.b16 - -; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x1.trans.b8 -; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x2.trans.b8 -; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x4.trans.b8 -; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x1.trans.shared.b8 -; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x2.trans.shared.b8 -; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x4.trans.shared.b8 - ; PTX71MMA-DAG: mma.m8n8k4.row.col.f64 ; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32 ; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32 @@ -1163,7 +1039,6 @@ def gen_tests(): items = gen_wmma_load_tests() items += gen_wmma_store_tests() items += gen_ldmatrix_tests() - items += gen_stmatrix_tests() items += gen_wmma_mma_tests() items += gen_mma_tests() gen_check_unsupported_ops(items) From 56db71e89825f6d727f14e7bdced49019fb63380 Mon Sep 17 00:00:00 2001 From: Gao Yanfeng Date: Wed, 16 Jul 2025 10:20:09 +0800 Subject: [PATCH 3/4] Move the changes of IntrinsicsNVVM.td to another PR --- llvm/include/llvm/IR/IntrinsicsNVVM.td | 65 -------------------------- 1 file changed, 65 deletions(-) diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td index aad21fd4cba1c..0375f29ad8906 100644 --- a/llvm/include/llvm/IR/IntrinsicsNVVM.td +++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td @@ -331,11 +331,6 @@ class WMMA_REGS { !eq(gf,"m8n16:x2") : !listsplat(llvm_i32_ty, 2), !eq(gf,"m8n16:x4") : !listsplat(llvm_i32_ty, 4), - // stmatrix b8 -> s32 @ m16n8 - !eq(gf,"m16n8:x1") : !listsplat(llvm_i32_ty, 1), - !eq(gf,"m16n8:x2") : !listsplat(llvm_i32_ty, 2), - !eq(gf,"m16n8:x4") : !listsplat(llvm_i32_ty, 4), - ); } @@ -408,17 +403,6 @@ class LDMATRIX_NAME { !subst("llvm.", "int_", intr)); } -class STMATRIX_NAME { - string intr = "llvm.nvvm.stmatrix.sync.aligned" - # "." # Frag.geom - # "." # Frag.frag - # !if(Trans, ".trans", "") - # "." # Frag.ptx_elt_type - ; - string record = !subst(".", "_", - !subst("llvm.", "int_", intr)); -} - // Generates list of 4-tuples of WMMA_REGS representing a valid MMA op. // Geom: list of supported geometries. // TypeN: PTX type of the corresponding fragment's element. @@ -459,16 +443,6 @@ class LDMATRIX_OPS Geom, list Frags, list Types> { list ops = !foreach(x, ret, x.gft); } -class STMATRIX_OPS Geom, list Frags, list Types> { - list ret = - !foldl([], Geom, t1, geom, !listconcat(t1, - !foldl([], Frags, t2, frag, !listconcat(t2, - !foldl([], Types, t3, type, !listconcat(t3, - [WMMA_REGS])))))); - // Debugging aid for readable representation of the list above. - list ops = !foreach(x, ret, x.gft); -} - // Creates list of valid combinations of fragments. This is the main list that // drives generation of corresponding intrinsics and instructions. 
class NVVM_MMA_OPS { @@ -563,18 +537,9 @@ class NVVM_MMA_OPS { list ldmatrix_geom_m8n16_ops = LDMATRIX_OPS< ["m8n16"], ["x1", "x2", "x4"], ["b8x16.b6x16_p32", "b8x16.b4x16_p64"]>.ret; - list stmatrix_b16_ops = STMATRIX_OPS< - ["m8n8"], ["x1", "x2", "x4"], ["b16"]>.ret; - - list stmatrix_b8_ops = STMATRIX_OPS< - ["m16n8"], ["x1", "x2", "x4"], ["b8"]>.ret; - list all_ldmatrix_ops = !listconcat(ldmatrix_b16_ops, ldmatrix_geom_m16n16_ops, ldmatrix_geom_m8n16_ops); - - list all_stmatrix_ops = !listconcat(stmatrix_b16_ops, - stmatrix_b8_ops); } def NVVM_MMA_OPS : NVVM_MMA_OPS; @@ -715,19 +680,6 @@ class NVVM_LDMATRIX_SUPPORTED { ); } -// Returns true if the fragment is valid for stmatrix ops is supported; -// false otherwise. -class NVVM_STMATRIX_SUPPORTED { - string g = frag.geom; - string t = frag.ptx_elt_type; - - bit ret = !cond( - !and(!eq(g, "m8n8"), !eq(t, "b16")): true, - !and(!eq(g, "m16n8"), !eq(t, "b8"), !eq(trans, 1)): true, - true: false - ); -} - class SHFL_INFO { string Suffix = !if(sync, "sync_", "") # mode # "_" @@ -2017,23 +1969,6 @@ foreach transposed = [0, 1] in { } } -// STMATRIX -class NVVM_STMATRIX - : Intrinsic<[], - !listconcat([llvm_anyptr_ty], Frag.regs), - [IntrWriteMem, IntrArgMemOnly, IntrNoCallback, - WriteOnly>, NoCapture>], - STMATRIX_NAME.intr>; - -foreach transposed = [0, 1] in { - foreach frag = NVVM_MMA_OPS.all_stmatrix_ops in { - if NVVM_STMATRIX_SUPPORTED.ret then { - def STMATRIX_NAME.record - : NVVM_STMATRIX; - } - } -} - // MAPA let IntrProperties = [IntrNoMem, IntrSpeculatable, NoCapture>] in { def int_nvvm_mapa From e5a277a3ad4735cc67ecd82ccf4597dbddde355f Mon Sep 17 00:00:00 2001 From: Gao Yanfeng Date: Wed, 16 Jul 2025 14:22:56 +0800 Subject: [PATCH 4/4] Modify the arguments of the stmatrix op --- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 26 ++++--- .../Dialect/NVVM/NVVMToLLVMIRTranslation.cpp | 73 ++++++++++--------- mlir/test/Target/LLVMIR/nvvmir.mlir | 18 ++--- 3 files changed, 63 insertions(+), 54 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 8de5932aaf2e3..af9c29274dc1d 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -1990,22 +1990,30 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">, let hasVerifier = 1; } -def LdStMatrixShapeM8N8 : I32EnumAttrCase<"M8N8", 0, "m8n8">; -def LdStMatrixShapeM8N16 : I32EnumAttrCase<"M8N16", 1, "m8n16">; -def LdStMatrixShapeM16N8 : I32EnumAttrCase<"M16N8", 2, "m16n8">; -def LdStMatrixShapeM16N16 : I32EnumAttrCase<"M16N16", 3, "m16n16">; +def LdStMatrixShapeAttr : NVVM_Attr<"LdStMatrixShape", "ld_st_matrix_shape"> { + let summary = "Matrix shape for ldmatrix and stmatrix"; + let parameters = (ins "int":$m, "int":$n); + let assemblyFormat = "`<` struct(params) `>`"; +} + +def LdStMatrixEltTypeB16 : I32EnumAttrCase<"B16", 0, "b16">; +def LdStMatrixEltTypeB8 : I32EnumAttrCase<"B8", 1, "b8">; +def LdStMatrixEltTypeB8X16_B6X16_P32 : I32EnumAttrCase<"B8X16_B6X16_P32", 2, "b8x16.b6x16_p32">; +def LdStMatrixEltTypeB8X16_B4X16_P64 : I32EnumAttrCase<"B8X16_B4X16_P64", 3, "b8x16.b4x16_p64">; -def LdStMatrixShape : I32EnumAttr<"LdStMatrixShape", "Matrix shape for ldmatrix and stmatrix", - [LdStMatrixShapeM8N8, LdStMatrixShapeM16N8, LdStMatrixShapeM16N16]> { +def LdStMatrixEltType : I32EnumAttr<"LdStMatrixEltType", "Element type for ldmatrix and stmatrix", + [LdStMatrixEltTypeB16, LdStMatrixEltTypeB8, + LdStMatrixEltTypeB8X16_B6X16_P32, LdStMatrixEltTypeB8X16_B4X16_P64]> { let genSpecializedAttr 
= 0; let cppNamespace = "::mlir::NVVM"; } -def LdStMatrixShapeAttr : EnumAttr { +def LdStMatrixEltTypeAttr : EnumAttr { let assemblyFormat = "`<` $value `>`"; } def NVVM_StMatrixOp: NVVM_Op<"stmatrix">, - Arguments<(ins LLVM_AnyPointer: $ptr, Variadic:$sources, MMALayoutAttr:$layout, LdStMatrixShapeAttr:$shape)> { + Arguments<(ins LLVM_AnyPointer: $ptr, Variadic:$sources, MMALayoutAttr:$layout, + LdStMatrixShapeAttr:$shape, LdStMatrixEltTypeAttr:$elttype)> { let summary = "cooperative matrix store"; let description = [{ Collectively store one or more matrices across all threads in a warp to the @@ -2015,7 +2023,7 @@ def NVVM_StMatrixOp: NVVM_Op<"stmatrix">, }]; string llvmBuilder = [{ auto operands = moduleTranslation.lookupValues(opInst.getOperands()); - auto intId = getStMatrixIntrinsicId($layout, $sources.size(), $shape); + auto intId = getStMatrixIntrinsicId($layout, $sources.size(), $shape, $elttype); createIntrinsicCall(builder, intId, operands, operands[0]->getType()); }]; let assemblyFormat = "$ptr `,` $sources attr-dict `:` type(operands)"; diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp index d03242f402ec5..3491f658529c2 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp @@ -164,46 +164,47 @@ static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout, } /// Return the intrinsic ID associated with stmatrix for the given paramters. -static llvm::Intrinsic::ID getStMatrixIntrinsicId(NVVM::MMALayout layout, - int32_t num, - NVVM::LdStMatrixShape shape) { - if (shape == NVVM::LdStMatrixShape::M8N8) { - if (layout == NVVM::MMALayout::row) { - switch (num) { - case 1: - return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16; - case 2: - return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16; - case 4: - return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16; - default: - llvm_unreachable("unsupported number of matrix"); - } - } else { - switch (num) { - case 1: - return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16; - case 2: - return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16; - case 4: - return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16; - default: - llvm_unreachable("unsupported number of matrix"); +static llvm::Intrinsic::ID +getStMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, + NVVM::LdStMatrixShapeAttr shape, + NVVM::LdStMatrixEltType eltType) { + if (shape.getM() == 8 && shape.getN() == 8) { + if (eltType == NVVM::LdStMatrixEltType::B16) { + if (layout == NVVM::MMALayout::row) { + switch (num) { + case 1: + return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16; + case 2: + return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16; + case 4: + return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16; + } + } else { + switch (num) { + case 1: + return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16; + case 2: + return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16; + case 4: + return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16; + } } } - } else { - // for 16x8 matrices, .trans is mandatory - switch (num) { - case 1: - return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8; - case 2: - return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8; - case 4: - return 
llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8; - default: - llvm_unreachable("unsupported number of matrix"); + } else if (shape.getM() == 16 && shape.getN() == 8) { + if (eltType == NVVM::LdStMatrixEltType::B8) { + if (layout == NVVM::MMALayout::col) { + switch (num) { + case 1: + return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8; + case 2: + return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8; + case 4: + return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8; + } + } } } + llvm_unreachable("unknown stmatrix kind"); } /// Return the intrinsic ID associated with st.bulk for the given address type. diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir index 3be35faf091e2..ad3e67b039d8f 100644 --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -576,23 +576,23 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { // CHECK-LABEL: @st_matrix llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}) - nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape} : !llvm.ptr<3>, i32 + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32 // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}) - nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape} : !llvm.ptr<3>, i32 + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32 // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x1.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}) - nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape} : !llvm.ptr<3>, i32 + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32 // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) - nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape} : !llvm.ptr<3>, i32, i32 + nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32, i32 // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) - nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape} : !llvm.ptr<3>, i32, i32 + nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32, i32 // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x2.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) - nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape} : !llvm.ptr<3>, i32, i32 + nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : !llvm.ptr<3>, i32, i32 // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x4.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) - nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 
{layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>} : !llvm.ptr<3>, i32, i32, i32, i32
+  nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : !llvm.ptr<3>, i32, i32, i32, i32
   // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x4.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
-  nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m8n8>} : !llvm.ptr<3>, i32, i32, i32, i32
+  nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : !llvm.ptr<3>, i32, i32, i32, i32
   // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x4.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
-  nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m16n8>} : !llvm.ptr<3>, i32, i32, i32, i32
+  nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b8>} : !llvm.ptr<3>, i32, i32, i32, i32
   llvm.return
 }
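
For reference, a minimal sketch of the LLVM IR this translation path is expected to produce for the x2 variants, assuming a shared-memory pointer so the anyptr overload mangles to .p3; the function and value names are placeholders, not part of this series:

  declare void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.b16.p3(ptr addrspace(3), i32, i32)
  declare void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.trans.b16.p3(ptr addrspace(3), i32, i32)
  declare void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x2.trans.b8.p3(ptr addrspace(3), i32, i32)

  define void @stmatrix_x2_lowering(ptr addrspace(3) %p, i32 %a, i32 %b) {
    ; layout = row, shape = m8n8, elttype = b16 -> plain b16 store
    call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.b16.p3(ptr addrspace(3) %p, i32 %a, i32 %b)
    ; layout = col, shape = m8n8, elttype = b16 -> .trans variant
    call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.trans.b16.p3(ptr addrspace(3) %p, i32 %a, i32 %b)
    ; layout = col, shape = m16n8, elttype = b8 -> .trans is the only legal form
    call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x2.trans.b8.p3(ptr addrspace(3) %p, i32 %a, i32 %b)
    ret void
  }

Any (shape, elttype, layout) combination outside these families falls through to the llvm_unreachable at the end of getStMatrixIntrinsicId, so rejecting it earlier is left to the op's verifier (hasVerifier = 1).
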