Skip to content

[MLIR][NVVM][NVPTX] Support intrinsics about stmatrix #148377

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

Pecco-314
Copy link

Add support for the @llvm.nvvm.stmatrix intrinsic series. These correspond to PTX stmatrix operations, as documented in the PTX ISA reference.

@Pecco-314 Pecco-314 requested a review from grypp as a code owner July 12, 2025 15:21
Copy link

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot
Copy link
Member

llvmbot commented Jul 12, 2025

@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-backend-nvptx

Author: Pecco (Pecco-314)

Changes

Add support for the @<!-- -->llvm.nvvm.stmatrix intrinsic series. These correspond to PTX stmatrix operations, as documented in the PTX ISA reference.


Patch is 29.97 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/148377.diff

12 Files Affected:

  • (modified) llvm/include/llvm/IR/IntrinsicsNVVM.td (+65)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+28-1)
  • (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+43-2)
  • (added) llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py (+14)
  • (modified) llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py (+1-3)
  • (modified) llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py (+1-3)
  • (modified) llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py (+1-3)
  • (modified) llvm/test/CodeGen/NVPTX/wmma.py (+125)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+21-18)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp (+43)
  • (modified) mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir (-24)
  • (modified) mlir/test/Target/LLVMIR/nvvmir.mlir (+23)
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 0375f29ad8906..aad21fd4cba1c 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -331,6 +331,11 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
     !eq(gf,"m8n16:x2") : !listsplat(llvm_i32_ty, 2),
     !eq(gf,"m8n16:x4") : !listsplat(llvm_i32_ty, 4),
 
+    // stmatrix b8 -> s32 @ m16n8
+    !eq(gf,"m16n8:x1") : !listsplat(llvm_i32_ty, 1),
+    !eq(gf,"m16n8:x2") : !listsplat(llvm_i32_ty, 2),
+    !eq(gf,"m16n8:x4") : !listsplat(llvm_i32_ty, 4),
+
   );
 }
 
@@ -403,6 +408,17 @@ class LDMATRIX_NAME<WMMA_REGS Frag, int Trans> {
                   !subst("llvm.", "int_", intr));
 }
 
+class STMATRIX_NAME<WMMA_REGS Frag, int Trans> {
+  string intr = "llvm.nvvm.stmatrix.sync.aligned"
+                # "." # Frag.geom
+                # "." # Frag.frag
+                # !if(Trans, ".trans", "")
+                # "." # Frag.ptx_elt_type
+                ;
+  string record = !subst(".", "_",
+                  !subst("llvm.", "int_", intr));
+}
+
 // Generates list of 4-tuples of WMMA_REGS representing a valid MMA op.
 //   Geom: list of supported geometries.
 //   TypeN: PTX type of the corresponding fragment's element.
@@ -443,6 +459,16 @@ class LDMATRIX_OPS<list<string> Geom, list<string> Frags, list<string> Types> {
    list<string> ops = !foreach(x, ret, x.gft);
 }
 
+class STMATRIX_OPS<list<string> Geom, list<string> Frags, list<string> Types> {
+  list<WMMA_REGS> ret =
+     !foldl([]<WMMA_REGS>, Geom, t1, geom, !listconcat(t1,
+     !foldl([]<WMMA_REGS>, Frags, t2, frag, !listconcat(t2,
+     !foldl([]<WMMA_REGS>, Types, t3, type, !listconcat(t3,
+            [WMMA_REGS<geom, frag, type>]))))));
+   // Debugging aid for readable representation of the list above.
+   list<string> ops = !foreach(x, ret, x.gft);
+}
+
 // Creates list of valid combinations of fragments. This is the main list that
 // drives generation of corresponding intrinsics and instructions.
 class NVVM_MMA_OPS {
@@ -537,9 +563,18 @@ class NVVM_MMA_OPS {
   list<WMMA_REGS> ldmatrix_geom_m8n16_ops = LDMATRIX_OPS<
     ["m8n16"], ["x1", "x2", "x4"], ["b8x16.b6x16_p32", "b8x16.b4x16_p64"]>.ret;
 
+  list<WMMA_REGS> stmatrix_b16_ops = STMATRIX_OPS<
+    ["m8n8"], ["x1", "x2", "x4"], ["b16"]>.ret;
+
+  list<WMMA_REGS> stmatrix_b8_ops = STMATRIX_OPS<
+    ["m16n8"], ["x1", "x2", "x4"], ["b8"]>.ret;
+
   list<WMMA_REGS> all_ldmatrix_ops = !listconcat(ldmatrix_b16_ops,
                                                  ldmatrix_geom_m16n16_ops,
                                                  ldmatrix_geom_m8n16_ops);
+
+  list<WMMA_REGS> all_stmatrix_ops = !listconcat(stmatrix_b16_ops,
+                                                 stmatrix_b8_ops);
 }
 
 def NVVM_MMA_OPS : NVVM_MMA_OPS;
@@ -680,6 +715,19 @@ class NVVM_LDMATRIX_SUPPORTED<WMMA_REGS frag, bit trans> {
   );
 }
 
+// Returns true if the fragment is valid for stmatrix ops is supported;
+// false otherwise.
+class NVVM_STMATRIX_SUPPORTED<WMMA_REGS frag, bit trans> {
+  string g = frag.geom;
+  string t = frag.ptx_elt_type;
+
+  bit ret = !cond(
+    !and(!eq(g, "m8n8"), !eq(t, "b16")): true,
+    !and(!eq(g, "m16n8"), !eq(t, "b8"), !eq(trans, 1)): true,
+    true: false
+  );
+}
+
 class SHFL_INFO<bit sync, string mode, string type, bit return_pred> {
   string Suffix = !if(sync, "sync_", "")
                   # mode # "_"
@@ -1969,6 +2017,23 @@ foreach transposed = [0, 1] in {
   }
 }
 
+// STMATRIX
+class NVVM_STMATRIX<WMMA_REGS Frag, int Transposed>
+  : Intrinsic<[],
+          !listconcat([llvm_anyptr_ty], Frag.regs),
+          [IntrWriteMem, IntrArgMemOnly, IntrNoCallback,
+           WriteOnly<ArgIndex<0>>, NoCapture<ArgIndex<0>>],
+          STMATRIX_NAME<Frag, Transposed>.intr>;
+
+foreach transposed = [0, 1] in {
+  foreach frag = NVVM_MMA_OPS.all_stmatrix_ops in {
+    if NVVM_STMATRIX_SUPPORTED<frag, transposed>.ret then {
+      def STMATRIX_NAME<frag, transposed>.record
+        : NVVM_STMATRIX<frag, transposed>;
+    }
+  }
+}
+
 // MAPA
 let IntrProperties = [IntrNoMem, IntrSpeculatable, NoCapture<ArgIndex<0>>] in {
   def int_nvvm_mapa
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 3d010e04824c5..d94be492b0c02 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -3952,7 +3952,10 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
   case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col:
   case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col_stride:
   case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row:
-  case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride: {
+  case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride:
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16:
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16:
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8: {
     Info.opc = ISD::INTRINSIC_VOID;
     Info.memVT = MVT::v2i32;
     Info.ptrVal = I.getArgOperand(0);
@@ -3975,6 +3978,30 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
     return true;
   }
 
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16:
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16:
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8: {
+    Info.opc = ISD::INTRINSIC_VOID;
+    Info.memVT = MVT::i32;
+    Info.ptrVal = I.getArgOperand(0);
+    Info.offset = 0;
+    Info.flags = MachineMemOperand::MOStore;
+    Info.align = Align(4);
+    return true;
+  }
+
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16:
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16:
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8: {
+    Info.opc = ISD::INTRINSIC_VOID;
+    Info.memVT = MVT::v4i32;
+    Info.ptrVal = I.getArgOperand(0);
+    Info.offset = 0;
+    Info.flags = MachineMemOperand::MOStore;
+    Info.align = Align(16);
+    return true;
+  }
+
   case Intrinsic::nvvm_atomic_add_gen_f_cta:
   case Intrinsic::nvvm_atomic_add_gen_f_sys:
   case Intrinsic::nvvm_atomic_add_gen_i_cta:
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 93827be5c2811..1e24bf8ab99e1 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -4597,7 +4597,14 @@ class WMMA_REGINFO<WMMA_REGS r, string op>
 
     !and(!eq(op, "ldmatrix"),
          !eq(ptx_elt_type, "b8x16.b4x16_p64"),
-         !eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>]);
+         !eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>],
+
+    !and(!eq(op, "stmatrix"),!eq(ptx_elt_type, "b16"),
+         !eq(geom, "m8n8")) : [hasSM<90>, hasPTX<78>],
+
+    !and(!eq(op, "stmatrix"),
+         !eq(ptx_elt_type, "b8"),
+         !eq(geom, "m16n8")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>]);
 
   // template DAGs for instruction inputs/output.
   dag Outs = !dag(outs, ptx_regs, reg_names);
@@ -4878,6 +4885,40 @@ defset list<WMMA_INSTR> LDMATRIXs  = {
   } // transposed
 } // defset
 
+//
+// stmatrix.sync.aligned.m8n8[|.trans][|.shared].b16
+//
+class STMATRIX<WMMA_REGINFO Frag, bit Transposed, string Space>
+  : WMMA_INSTR<STMATRIX_NAME<Frag, Transposed>.record, [!con((ins ADDR:$dst), Frag.Ins)]>,
+    Requires<Frag.Predicates> {
+  // Build PatFrag that only matches particular address space.
+  dag PFOperands = !con((ops node:$dst), !dag(ops, !listsplat(node, !size(Frag.regs)), Frag.reg_names));
+  PatFrag IntrFrag = PatFrag<PFOperands, !foreach(tmp, PFOperands, !subst(ops, Intr, tmp)), 
+                             !cond(!eq(Space, ".shared"): AS_match.shared,
+                                   true: AS_match.generic)>;
+  // Build AS-constrained pattern.
+  let IntrinsicPattern = BuildPatternPF<IntrFrag, Args>.ret;
+  let OutOperandList = (outs);
+  let InOperandList = !con(Args, (ins MmaCode:$ptx));
+  let AsmString = "stmatrix.sync.aligned."
+                  # Frag.geom
+                  # "." # Frag.frag
+                  # !if(Transposed, ".trans", "")
+                  # Space
+                  # "." # Frag.ptx_elt_type
+                  # " [$dst], " # Frag.regstring # ";";
+}
+
+// Create all stmatrix variants
+defset list<WMMA_INSTR> STMATRIXs = {
+  foreach transposed = [false, true] in {foreach space = [".shared", ""] in {
+      foreach frag = NVVM_MMA_OPS.all_stmatrix_ops in
+        if NVVM_STMATRIX_SUPPORTED<frag, transposed>.ret then
+          def : STMATRIX<WMMA_REGINFO<frag, "stmatrix">, transposed, space>;
+    } // space
+  } // transposed
+} // defset
+
 // Constructing non-flat DAGs is still a pain. I can't !subst a dag node with a
 // dag, so the ptx.version must be appended *after* foreach replaces 'ins' with
 // the instruction record.
@@ -4888,7 +4929,7 @@ class MMA_PAT<WMMA_INSTR wi>
         Requires<wi.Predicates>;
 
 // Build intrinsic->instruction patterns for all MMA instructions.
-foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs) in
+foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs, STMATRIXs) in
   def : MMA_PAT<mma>;
 
 multiclass MAPA<string suffix, Intrinsic Intr> {
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py b/llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py
new file mode 100644
index 0000000000000..8f502065345c1
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py
@@ -0,0 +1,14 @@
+# Check all variants of instructions supported by PTX78 on SM90
+# RUN: %python %s --ptx=78 --gpu-arch=90 --aa > %t-ptx78-sm_90.ll
+# RUN: FileCheck %t-ptx78-sm_90.ll < %t-ptx78-sm_90.ll \
+# RUN:           --check-prefixes=PTX78STMATRIX-DAG
+# RUN: llc < %t-ptx78-sm_90.ll -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78 \
+# RUN:           | FileCheck %t-ptx78-sm_90.ll
+# RUN: %if ptxas-12.7 %{                                                  \
+# RUN: llc < %t-ptx78-sm_90.ll -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78 \
+# RUN:           | %ptxas-verify -arch=sm_90                              \
+# RUN: %}
+
+import wmma
+
+wmma.main()
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py
index 6ad0a2a5865c4..5c14a54601ed9 100644
--- a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py
@@ -1,9 +1,7 @@
 # Check all variants of instructions supported by PTX86 on SM100a
 # RUN: %python %s --ptx=86 --gpu-arch=100 --aa > %t-ptx86-sm_100a.ll
 # RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \
-# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
-# RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \
-# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
+# RUN:           --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG
 # RUN: llc < %t-ptx86-sm_100a.ll -mtriple=nvptx64 -mcpu=sm_100a -mattr=+ptx86 \
 # RUN:           | FileCheck %t-ptx86-sm_100a.ll
 # RUN: %if ptxas-12.7 %{                                                  \
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py
index 7d9953484da7d..a77f9adddff9c 100644
--- a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py
@@ -1,9 +1,7 @@
 # Check all variants of instructions supported by PTX86 on SM101a
 # RUN: %python %s --ptx=86 --gpu-arch=101 --aa > %t-ptx86-sm_101a.ll
 # RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \
-# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
-# RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \
-# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
+# RUN:           --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG
 # RUN: llc < %t-ptx86-sm_101a.ll -mtriple=nvptx64 -mcpu=sm_101a -mattr=+ptx86 \
 # RUN:           | FileCheck %t-ptx86-sm_101a.ll
 # RUN: %if ptxas-12.7 %{                                                  \
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py
index 7bddf0b6fbb78..8126e64d6cc85 100644
--- a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py
@@ -1,9 +1,7 @@
 # Check all variants of instructions supported by PTX86 on SM120a
 # RUN: %python %s --ptx=86 --gpu-arch=120 --aa > %t-ptx86-sm_120a.ll
 # RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \
-# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
-# RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \
-# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
+# RUN:           --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG
 # RUN: llc < %t-ptx86-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx86 \
 # RUN:           | FileCheck %t-ptx86-sm_120a.ll
 # RUN: %if ptxas-12.7 %{                                                  \
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index 2ee489670e9e4..3888e9b6b1b8d 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -10,6 +10,7 @@
 from itertools import product
 from string import Template
 
+
 class MMAType:
     def __init__(self, ptx_type):
         self.ptx_type = ptx_type
@@ -176,6 +177,13 @@ def __init__(self, geom, frag, ptx_elt_type):
             "m8n16:x1:b8x16.b4x16_p64": 1,
             "m8n16:x2:b8x16.b4x16_p64": 2,
             "m8n16:x4:b8x16.b4x16_p64": 4,
+            # stmatrix
+            "m8n8:x1:b16": 1,
+            "m8n8:x2:b16": 2,
+            "m8n8:x4:b16": 4,
+            "m16n8:x1:b8": 1,
+            "m16n8:x2:b8": 2,
+            "m16n8:x4:b8": 4,
         }.get(
             "%s:%s:%s" % (geom, frag, ptx_elt_type),
             {
@@ -241,6 +249,13 @@ def make_ldmatrix_ops(geoms, frags, types):
     ]
 
 
+def make_stmatrix_ops(geoms, frags, types):
+    return [
+        MMAFrag(geom, frag, ptx_type)
+        for (geom, frag, ptx_type) in product(geoms, frags, types)
+    ]
+
+
 def get_wmma_ops():
     return (
         make_mma_ops(["m16n16k8"], ["tf32"], [], ["f32"], [])
@@ -315,6 +330,12 @@ def get_ldmatrix_ops():
     )
 
 
+def get_stmatrix_ops():
+    return make_stmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"]) + make_stmatrix_ops(
+        ["m16n8"], ["x1", "x2", "x4"], ["b8"]
+    )
+
+
 def is_wmma_geom_supported(geom):
     # geometries for FP and ints.
     if geom in ["m8n32k16", "m32n8k16"]:
@@ -360,6 +381,14 @@ def is_ldmatrix_geom_supported(geom):
     assert False  # Unexpected geometry.
 
 
+def is_stmatrix_geom_supported(geom):
+    if geom in ["m8n8"]:
+        return ptx_version >= 78 and gpu_arch >= 90
+    elif geom in ["m16n8"]:
+        return ptx_version >= 86 and gpu_arch >= 100 and aa
+    assert False  # Unexpected geometry.
+
+
 def is_ldmatrix_trans_supported(geom, trans):
     if geom in ["m8n8"]:
         return True
@@ -369,6 +398,15 @@ def is_ldmatrix_trans_supported(geom, trans):
         return trans == ""
     assert False  # Unexpected geometry.
 
+
+def is_stmatrix_trans_supported(geom, trans):
+    if geom in ["m8n8"]:
+        return True
+    elif geom in ["m16n8"]:
+        return trans == ".trans"
+    assert False  # Unexpected geometry.
+
+
 def is_type_supported(ptx_type):
     if ptx_type in ["s8", "u8", "s32"]:
         return ptx_version >= 63 and gpu_arch >= 72
@@ -463,6 +501,16 @@ def is_ldmatrix_variant_supported(frag, trans):
     return frag.frag in ["x1", "x2", "x4"]
 
 
+def is_stmatrix_variant_supported(frag, trans):
+    if not (
+        is_type_supported(frag.mma_type.ptx_type)
+        and is_stmatrix_geom_supported(frag.geom)
+        and is_stmatrix_trans_supported(frag.geom, trans)
+    ):
+        return False
+    return frag.frag in ["x1", "x2", "x4"]
+
+
 def make_wmma_slice_ty(frag):
     return [frag.mma_type.llvm_type] * frag.nregs
 
@@ -716,6 +764,61 @@ def gen_ldmatrix_tests():
 
     return generated_items
 
+def gen_stmatrix_tests():
+    stmatrix_template = """
+declare void @${intrinsic}(i8 ${as}* %dst, ${args});
+
+; CHECK-LABEL: .func {{.*}}test_${function}(
+define void @test_${function}(i8 ${as}* %dst, ${args}) {
+; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}]
+; CHECK: {${check_args}}
+  call void @${intrinsic}(i8${as}* %dst, ${args});
+  ret void
+}
+
+; CHECK-LABEL: .func{{.*}}test_${function}_o(
+define void @test_${function}_o(i8 ${as}* %dst, ${args}) {
+; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}+128],
+; CHECK: {${check_args}}
+  %dst1 = getelementptr i8, i8 ${as}* %dst, i32 128;
+  call void @${intrinsic}(i8 ${as}* %dst1, ${args});
+  ret void
+}
+"""
+    intrinsic_template = (
+        "llvm.nvvm.stmatrix.sync.aligned.${geom}.${frag}${trans}.${itype}.${pspace}"
+    )
+    instruction_template = ("stmatrix.sync.aligned.${geom}.${frag}${trans}${space}.${itype}"
+    )
+    generated_items = []
+
+    for frag, space, trans in product(get_stmatrix_ops(),
+        ["", ".shared"],
+        ["", ".trans"],
+    ):
+        if not is_stmatrix_variant_supported(frag, trans):
+            continue
+
+        params = {
+            "frag": frag.frag,
+            "space": space,"trans": trans,
+            "itype": frag.mma_type.ptx_type,
+            "pspace": get_pspace(space),
+            "as": "addrspace(%d)" % get_aspace(space),
+            "geom": frag.geom,
+        }
+
+        test_params = params
+        test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
+        test_params["function"] = test_params["intrinsic"].replace(".", "_")
+        test_params["instruction"] = Template(instruction_template).substitute(params)
+        test_params["args"] = make_wmma_slice_args(frag)
+        test_params["check_args"] = check_pattern(frag)
+
+        print(Template(stmatrix_template).substitute(test_params))
+        generated_items.append((test_params["intrinsic"], test_params["instruction"]))
+
+    return generated_items
 
 def mma_signature(op):
     if op.a.mma_type.ptx_type == "f16":
@@ -893,6 +996,7 @@ def gen_check_unsupported_ops(items):
 ; NOALTFLOAT-NOT: .{{bf16|tf32}}
 ; NODOUBLE-NOT: .f64
 ; NOLDMATRIX-NOT: ldmatrix.sync.aligned
+; NOSTMATRIX-NOT: stmatrix.sync.aligned
 
 ; M16N16-DAG: m16n16k16.load.{{[ab].*}}.f16.p
 ; M16N16-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
@@ -994,6 +1098,26 @@ def gen_check_unsupported_ops(items):
 ; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32
 ; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64
 
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.trans.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.trans.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.trans.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.shared.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.shared.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.shared.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.trans.shared.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.trans.shared.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.trans.shared.b16
+
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x1.trans.b8
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x2.trans.b8
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x4.trans.b8
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x1.trans.shared.b8
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x2.trans.shared.b8
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x4.trans.shared.b8
+
 ; PTX71MMA-DAG: mma.m8n8k4.row.col.f64
 ; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32
 ; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32
@@ -1039,6 +1163,7 @@ def gen_tests():
     items = gen_wmma_load_tests()
     items += gen_wmma_store_tests()
     items += gen_ldmatrix_tests()
+    items += gen_stmatrix_tests()
     items += gen_wmma_mma_tests()
     items += gen_mma_tests()
     gen_check_unsupported_ops(items)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 45a8904375e2b..8de5932aaf2e3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1990,10 +1990,22 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">,
   let hasVerifier = 1;
 }
 
-def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">, 
-  Ar...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Jul 12, 2025

@llvm/pr-subscribers-llvm-ir

Author: Pecco (Pecco-314)

Changes

Add support for the @<!-- -->llvm.nvvm.stmatrix intrinsic series. These correspond to PTX stmatrix operations, as documented in the PTX ISA reference.


Patch is 29.97 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/148377.diff

12 Files Affected:

  • (modified) llvm/include/llvm/IR/IntrinsicsNVVM.td (+65)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+28-1)
  • (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+43-2)
  • (added) llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py (+14)
  • (modified) llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py (+1-3)
  • (modified) llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py (+1-3)
  • (modified) llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py (+1-3)
  • (modified) llvm/test/CodeGen/NVPTX/wmma.py (+125)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+21-18)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp (+43)
  • (modified) mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir (-24)
  • (modified) mlir/test/Target/LLVMIR/nvvmir.mlir (+23)
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 0375f29ad8906..aad21fd4cba1c 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -331,6 +331,11 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
     !eq(gf,"m8n16:x2") : !listsplat(llvm_i32_ty, 2),
     !eq(gf,"m8n16:x4") : !listsplat(llvm_i32_ty, 4),
 
+    // stmatrix b8 -> s32 @ m16n8
+    !eq(gf,"m16n8:x1") : !listsplat(llvm_i32_ty, 1),
+    !eq(gf,"m16n8:x2") : !listsplat(llvm_i32_ty, 2),
+    !eq(gf,"m16n8:x4") : !listsplat(llvm_i32_ty, 4),
+
   );
 }
 
@@ -403,6 +408,17 @@ class LDMATRIX_NAME<WMMA_REGS Frag, int Trans> {
                   !subst("llvm.", "int_", intr));
 }
 
+class STMATRIX_NAME<WMMA_REGS Frag, int Trans> {
+  string intr = "llvm.nvvm.stmatrix.sync.aligned"
+                # "." # Frag.geom
+                # "." # Frag.frag
+                # !if(Trans, ".trans", "")
+                # "." # Frag.ptx_elt_type
+                ;
+  string record = !subst(".", "_",
+                  !subst("llvm.", "int_", intr));
+}
+
 // Generates list of 4-tuples of WMMA_REGS representing a valid MMA op.
 //   Geom: list of supported geometries.
 //   TypeN: PTX type of the corresponding fragment's element.
@@ -443,6 +459,16 @@ class LDMATRIX_OPS<list<string> Geom, list<string> Frags, list<string> Types> {
    list<string> ops = !foreach(x, ret, x.gft);
 }
 
+class STMATRIX_OPS<list<string> Geom, list<string> Frags, list<string> Types> {
+  list<WMMA_REGS> ret =
+     !foldl([]<WMMA_REGS>, Geom, t1, geom, !listconcat(t1,
+     !foldl([]<WMMA_REGS>, Frags, t2, frag, !listconcat(t2,
+     !foldl([]<WMMA_REGS>, Types, t3, type, !listconcat(t3,
+            [WMMA_REGS<geom, frag, type>]))))));
+   // Debugging aid for readable representation of the list above.
+   list<string> ops = !foreach(x, ret, x.gft);
+}
+
 // Creates list of valid combinations of fragments. This is the main list that
 // drives generation of corresponding intrinsics and instructions.
 class NVVM_MMA_OPS {
@@ -537,9 +563,18 @@ class NVVM_MMA_OPS {
   list<WMMA_REGS> ldmatrix_geom_m8n16_ops = LDMATRIX_OPS<
     ["m8n16"], ["x1", "x2", "x4"], ["b8x16.b6x16_p32", "b8x16.b4x16_p64"]>.ret;
 
+  list<WMMA_REGS> stmatrix_b16_ops = STMATRIX_OPS<
+    ["m8n8"], ["x1", "x2", "x4"], ["b16"]>.ret;
+
+  list<WMMA_REGS> stmatrix_b8_ops = STMATRIX_OPS<
+    ["m16n8"], ["x1", "x2", "x4"], ["b8"]>.ret;
+
   list<WMMA_REGS> all_ldmatrix_ops = !listconcat(ldmatrix_b16_ops,
                                                  ldmatrix_geom_m16n16_ops,
                                                  ldmatrix_geom_m8n16_ops);
+
+  list<WMMA_REGS> all_stmatrix_ops = !listconcat(stmatrix_b16_ops,
+                                                 stmatrix_b8_ops);
 }
 
 def NVVM_MMA_OPS : NVVM_MMA_OPS;
@@ -680,6 +715,19 @@ class NVVM_LDMATRIX_SUPPORTED<WMMA_REGS frag, bit trans> {
   );
 }
 
+// Returns true if the fragment is valid for stmatrix ops is supported;
+// false otherwise.
+class NVVM_STMATRIX_SUPPORTED<WMMA_REGS frag, bit trans> {
+  string g = frag.geom;
+  string t = frag.ptx_elt_type;
+
+  bit ret = !cond(
+    !and(!eq(g, "m8n8"), !eq(t, "b16")): true,
+    !and(!eq(g, "m16n8"), !eq(t, "b8"), !eq(trans, 1)): true,
+    true: false
+  );
+}
+
 class SHFL_INFO<bit sync, string mode, string type, bit return_pred> {
   string Suffix = !if(sync, "sync_", "")
                   # mode # "_"
@@ -1969,6 +2017,23 @@ foreach transposed = [0, 1] in {
   }
 }
 
+// STMATRIX
+class NVVM_STMATRIX<WMMA_REGS Frag, int Transposed>
+  : Intrinsic<[],
+          !listconcat([llvm_anyptr_ty], Frag.regs),
+          [IntrWriteMem, IntrArgMemOnly, IntrNoCallback,
+           WriteOnly<ArgIndex<0>>, NoCapture<ArgIndex<0>>],
+          STMATRIX_NAME<Frag, Transposed>.intr>;
+
+foreach transposed = [0, 1] in {
+  foreach frag = NVVM_MMA_OPS.all_stmatrix_ops in {
+    if NVVM_STMATRIX_SUPPORTED<frag, transposed>.ret then {
+      def STMATRIX_NAME<frag, transposed>.record
+        : NVVM_STMATRIX<frag, transposed>;
+    }
+  }
+}
+
 // MAPA
 let IntrProperties = [IntrNoMem, IntrSpeculatable, NoCapture<ArgIndex<0>>] in {
   def int_nvvm_mapa
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 3d010e04824c5..d94be492b0c02 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -3952,7 +3952,10 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
   case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col:
   case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col_stride:
   case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row:
-  case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride: {
+  case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride:
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16:
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16:
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8: {
     Info.opc = ISD::INTRINSIC_VOID;
     Info.memVT = MVT::v2i32;
     Info.ptrVal = I.getArgOperand(0);
@@ -3975,6 +3978,30 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
     return true;
   }
 
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16:
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16:
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8: {
+    Info.opc = ISD::INTRINSIC_VOID;
+    Info.memVT = MVT::i32;
+    Info.ptrVal = I.getArgOperand(0);
+    Info.offset = 0;
+    Info.flags = MachineMemOperand::MOStore;
+    Info.align = Align(4);
+    return true;
+  }
+
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16:
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16:
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8: {
+    Info.opc = ISD::INTRINSIC_VOID;
+    Info.memVT = MVT::v4i32;
+    Info.ptrVal = I.getArgOperand(0);
+    Info.offset = 0;
+    Info.flags = MachineMemOperand::MOStore;
+    Info.align = Align(16);
+    return true;
+  }
+
   case Intrinsic::nvvm_atomic_add_gen_f_cta:
   case Intrinsic::nvvm_atomic_add_gen_f_sys:
   case Intrinsic::nvvm_atomic_add_gen_i_cta:
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 93827be5c2811..1e24bf8ab99e1 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -4597,7 +4597,14 @@ class WMMA_REGINFO<WMMA_REGS r, string op>
 
     !and(!eq(op, "ldmatrix"),
          !eq(ptx_elt_type, "b8x16.b4x16_p64"),
-         !eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>]);
+         !eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>],
+
+    !and(!eq(op, "stmatrix"),!eq(ptx_elt_type, "b16"),
+         !eq(geom, "m8n8")) : [hasSM<90>, hasPTX<78>],
+
+    !and(!eq(op, "stmatrix"),
+         !eq(ptx_elt_type, "b8"),
+         !eq(geom, "m16n8")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>]);
 
   // template DAGs for instruction inputs/output.
   dag Outs = !dag(outs, ptx_regs, reg_names);
@@ -4878,6 +4885,40 @@ defset list<WMMA_INSTR> LDMATRIXs  = {
   } // transposed
 } // defset
 
+//
+// stmatrix.sync.aligned.m8n8[|.trans][|.shared].b16
+//
+class STMATRIX<WMMA_REGINFO Frag, bit Transposed, string Space>
+  : WMMA_INSTR<STMATRIX_NAME<Frag, Transposed>.record, [!con((ins ADDR:$dst), Frag.Ins)]>,
+    Requires<Frag.Predicates> {
+  // Build PatFrag that only matches particular address space.
+  dag PFOperands = !con((ops node:$dst), !dag(ops, !listsplat(node, !size(Frag.regs)), Frag.reg_names));
+  PatFrag IntrFrag = PatFrag<PFOperands, !foreach(tmp, PFOperands, !subst(ops, Intr, tmp)), 
+                             !cond(!eq(Space, ".shared"): AS_match.shared,
+                                   true: AS_match.generic)>;
+  // Build AS-constrained pattern.
+  let IntrinsicPattern = BuildPatternPF<IntrFrag, Args>.ret;
+  let OutOperandList = (outs);
+  let InOperandList = !con(Args, (ins MmaCode:$ptx));
+  let AsmString = "stmatrix.sync.aligned."
+                  # Frag.geom
+                  # "." # Frag.frag
+                  # !if(Transposed, ".trans", "")
+                  # Space
+                  # "." # Frag.ptx_elt_type
+                  # " [$dst], " # Frag.regstring # ";";
+}
+
+// Create all stmatrix variants
+defset list<WMMA_INSTR> STMATRIXs = {
+  foreach transposed = [false, true] in {foreach space = [".shared", ""] in {
+      foreach frag = NVVM_MMA_OPS.all_stmatrix_ops in
+        if NVVM_STMATRIX_SUPPORTED<frag, transposed>.ret then
+          def : STMATRIX<WMMA_REGINFO<frag, "stmatrix">, transposed, space>;
+    } // space
+  } // transposed
+} // defset
+
 // Constructing non-flat DAGs is still a pain. I can't !subst a dag node with a
 // dag, so the ptx.version must be appended *after* foreach replaces 'ins' with
 // the instruction record.
@@ -4888,7 +4929,7 @@ class MMA_PAT<WMMA_INSTR wi>
         Requires<wi.Predicates>;
 
 // Build intrinsic->instruction patterns for all MMA instructions.
-foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs) in
+foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs, STMATRIXs) in
   def : MMA_PAT<mma>;
 
 multiclass MAPA<string suffix, Intrinsic Intr> {
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py b/llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py
new file mode 100644
index 0000000000000..8f502065345c1
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py
@@ -0,0 +1,14 @@
+# Check all variants of instructions supported by PTX78 on SM90
+# RUN: %python %s --ptx=78 --gpu-arch=90 --aa > %t-ptx78-sm_90.ll
+# RUN: FileCheck %t-ptx78-sm_90.ll < %t-ptx78-sm_90.ll \
+# RUN:           --check-prefixes=PTX78STMATRIX-DAG
+# RUN: llc < %t-ptx78-sm_90.ll -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78 \
+# RUN:           | FileCheck %t-ptx78-sm_90.ll
+# RUN: %if ptxas-12.7 %{                                                  \
+# RUN: llc < %t-ptx78-sm_90.ll -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78 \
+# RUN:           | %ptxas-verify -arch=sm_90                              \
+# RUN: %}
+
+import wmma
+
+wmma.main()
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py
index 6ad0a2a5865c4..5c14a54601ed9 100644
--- a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py
@@ -1,9 +1,7 @@
 # Check all variants of instructions supported by PTX86 on SM100a
 # RUN: %python %s --ptx=86 --gpu-arch=100 --aa > %t-ptx86-sm_100a.ll
 # RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \
-# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
-# RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \
-# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
+# RUN:           --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG
 # RUN: llc < %t-ptx86-sm_100a.ll -mtriple=nvptx64 -mcpu=sm_100a -mattr=+ptx86 \
 # RUN:           | FileCheck %t-ptx86-sm_100a.ll
 # RUN: %if ptxas-12.7 %{                                                  \
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py
index 7d9953484da7d..a77f9adddff9c 100644
--- a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py
@@ -1,9 +1,7 @@
 # Check all variants of instructions supported by PTX86 on SM101a
 # RUN: %python %s --ptx=86 --gpu-arch=101 --aa > %t-ptx86-sm_101a.ll
 # RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \
-# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
-# RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \
-# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
+# RUN:           --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG
 # RUN: llc < %t-ptx86-sm_101a.ll -mtriple=nvptx64 -mcpu=sm_101a -mattr=+ptx86 \
 # RUN:           | FileCheck %t-ptx86-sm_101a.ll
 # RUN: %if ptxas-12.7 %{                                                  \
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py
index 7bddf0b6fbb78..8126e64d6cc85 100644
--- a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py
@@ -1,9 +1,7 @@
 # Check all variants of instructions supported by PTX86 on SM120a
 # RUN: %python %s --ptx=86 --gpu-arch=120 --aa > %t-ptx86-sm_120a.ll
 # RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \
-# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
-# RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \
-# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
+# RUN:           --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG
 # RUN: llc < %t-ptx86-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx86 \
 # RUN:           | FileCheck %t-ptx86-sm_120a.ll
 # RUN: %if ptxas-12.7 %{                                                  \
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index 2ee489670e9e4..3888e9b6b1b8d 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -10,6 +10,7 @@
 from itertools import product
 from string import Template
 
+
 class MMAType:
     def __init__(self, ptx_type):
         self.ptx_type = ptx_type
@@ -176,6 +177,13 @@ def __init__(self, geom, frag, ptx_elt_type):
             "m8n16:x1:b8x16.b4x16_p64": 1,
             "m8n16:x2:b8x16.b4x16_p64": 2,
             "m8n16:x4:b8x16.b4x16_p64": 4,
+            # stmatrix
+            "m8n8:x1:b16": 1,
+            "m8n8:x2:b16": 2,
+            "m8n8:x4:b16": 4,
+            "m16n8:x1:b8": 1,
+            "m16n8:x2:b8": 2,
+            "m16n8:x4:b8": 4,
         }.get(
             "%s:%s:%s" % (geom, frag, ptx_elt_type),
             {
@@ -241,6 +249,13 @@ def make_ldmatrix_ops(geoms, frags, types):
     ]
 
 
+def make_stmatrix_ops(geoms, frags, types):
+    return [
+        MMAFrag(geom, frag, ptx_type)
+        for (geom, frag, ptx_type) in product(geoms, frags, types)
+    ]
+
+
 def get_wmma_ops():
     return (
         make_mma_ops(["m16n16k8"], ["tf32"], [], ["f32"], [])
@@ -315,6 +330,12 @@ def get_ldmatrix_ops():
     )
 
 
+def get_stmatrix_ops():
+    return make_stmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"]) + make_stmatrix_ops(
+        ["m16n8"], ["x1", "x2", "x4"], ["b8"]
+    )
+
+
 def is_wmma_geom_supported(geom):
     # geometries for FP and ints.
     if geom in ["m8n32k16", "m32n8k16"]:
@@ -360,6 +381,14 @@ def is_ldmatrix_geom_supported(geom):
     assert False  # Unexpected geometry.
 
 
+def is_stmatrix_geom_supported(geom):
+    if geom in ["m8n8"]:
+        return ptx_version >= 78 and gpu_arch >= 90
+    elif geom in ["m16n8"]:
+        return ptx_version >= 86 and gpu_arch >= 100 and aa
+    assert False  # Unexpected geometry.
+
+
 def is_ldmatrix_trans_supported(geom, trans):
     if geom in ["m8n8"]:
         return True
@@ -369,6 +398,15 @@ def is_ldmatrix_trans_supported(geom, trans):
         return trans == ""
     assert False  # Unexpected geometry.
 
+
+def is_stmatrix_trans_supported(geom, trans):
+    if geom in ["m8n8"]:
+        return True
+    elif geom in ["m16n8"]:
+        return trans == ".trans"
+    assert False  # Unexpected geometry.
+
+
 def is_type_supported(ptx_type):
     if ptx_type in ["s8", "u8", "s32"]:
         return ptx_version >= 63 and gpu_arch >= 72
@@ -463,6 +501,16 @@ def is_ldmatrix_variant_supported(frag, trans):
     return frag.frag in ["x1", "x2", "x4"]
 
 
+def is_stmatrix_variant_supported(frag, trans):
+    if not (
+        is_type_supported(frag.mma_type.ptx_type)
+        and is_stmatrix_geom_supported(frag.geom)
+        and is_stmatrix_trans_supported(frag.geom, trans)
+    ):
+        return False
+    return frag.frag in ["x1", "x2", "x4"]
+
+
 def make_wmma_slice_ty(frag):
     return [frag.mma_type.llvm_type] * frag.nregs
 
@@ -716,6 +764,61 @@ def gen_ldmatrix_tests():
 
     return generated_items
 
+def gen_stmatrix_tests():
+    stmatrix_template = """
+declare void @${intrinsic}(i8 ${as}* %dst, ${args});
+
+; CHECK-LABEL: .func {{.*}}test_${function}(
+define void @test_${function}(i8 ${as}* %dst, ${args}) {
+; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}]
+; CHECK: {${check_args}}
+  call void @${intrinsic}(i8${as}* %dst, ${args});
+  ret void
+}
+
+; CHECK-LABEL: .func{{.*}}test_${function}_o(
+define void @test_${function}_o(i8 ${as}* %dst, ${args}) {
+; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}+128],
+; CHECK: {${check_args}}
+  %dst1 = getelementptr i8, i8 ${as}* %dst, i32 128;
+  call void @${intrinsic}(i8 ${as}* %dst1, ${args});
+  ret void
+}
+"""
+    intrinsic_template = (
+        "llvm.nvvm.stmatrix.sync.aligned.${geom}.${frag}${trans}.${itype}.${pspace}"
+    )
+    instruction_template = ("stmatrix.sync.aligned.${geom}.${frag}${trans}${space}.${itype}"
+    )
+    generated_items = []
+
+    for frag, space, trans in product(get_stmatrix_ops(),
+        ["", ".shared"],
+        ["", ".trans"],
+    ):
+        if not is_stmatrix_variant_supported(frag, trans):
+            continue
+
+        params = {
+            "frag": frag.frag,
+            "space": space,"trans": trans,
+            "itype": frag.mma_type.ptx_type,
+            "pspace": get_pspace(space),
+            "as": "addrspace(%d)" % get_aspace(space),
+            "geom": frag.geom,
+        }
+
+        test_params = params
+        test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
+        test_params["function"] = test_params["intrinsic"].replace(".", "_")
+        test_params["instruction"] = Template(instruction_template).substitute(params)
+        test_params["args"] = make_wmma_slice_args(frag)
+        test_params["check_args"] = check_pattern(frag)
+
+        print(Template(stmatrix_template).substitute(test_params))
+        generated_items.append((test_params["intrinsic"], test_params["instruction"]))
+
+    return generated_items
 
 def mma_signature(op):
     if op.a.mma_type.ptx_type == "f16":
@@ -893,6 +996,7 @@ def gen_check_unsupported_ops(items):
 ; NOALTFLOAT-NOT: .{{bf16|tf32}}
 ; NODOUBLE-NOT: .f64
 ; NOLDMATRIX-NOT: ldmatrix.sync.aligned
+; NOSTMATRIX-NOT: stmatrix.sync.aligned
 
 ; M16N16-DAG: m16n16k16.load.{{[ab].*}}.f16.p
 ; M16N16-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
@@ -994,6 +1098,26 @@ def gen_check_unsupported_ops(items):
 ; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32
 ; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64
 
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.trans.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.trans.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.trans.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.shared.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.shared.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.shared.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.trans.shared.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.trans.shared.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.trans.shared.b16
+
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x1.trans.b8
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x2.trans.b8
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x4.trans.b8
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x1.trans.shared.b8
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x2.trans.shared.b8
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x4.trans.shared.b8
+
 ; PTX71MMA-DAG: mma.m8n8k4.row.col.f64
 ; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32
 ; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32
@@ -1039,6 +1163,7 @@ def gen_tests():
     items = gen_wmma_load_tests()
     items += gen_wmma_store_tests()
     items += gen_ldmatrix_tests()
+    items += gen_stmatrix_tests()
     items += gen_wmma_mma_tests()
     items += gen_mma_tests()
     gen_check_unsupported_ops(items)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 45a8904375e2b..8de5932aaf2e3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1990,10 +1990,22 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">,
   let hasVerifier = 1;
 }
 
-def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">, 
-  Ar...
[truncated]

@Pecco-314 Pecco-314 changed the title [MLIR][NVVM][NVGPU] Support intrinsics about stmatrix [MLIR][NVVM][NVPTX] Support intrinsics about stmatrix Jul 12, 2025
@lezcano
Copy link

lezcano commented Jul 12, 2025

There are equivalent instrucctions for ldmatrix, will these go in a different PR?
At any rate, thank you for this contribution!

@Pecco-314
Copy link
Author

Pecco-314 commented Jul 12, 2025

There are equivalent instrucctions for ldmatrix, will these go in a different PR? At any rate, thank you for this contribution!

Yes. I note that ldmatrix intrinsics were already implemented (including new Blackwell ops). I'd like to propose another PR later to enable the NVVM ops to generate the corresponding ldmatrix intrinsics.

@schwarzschild-radius
Copy link
Contributor

@Pecco-314 Can you please split the PR into two separate ones? One with NVPTX and one with MLIR changes. Both requires two different reviews and approvals from two different code owners

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants