From fa44a32bf05036e9ce707317f0650ae165b74a94 Mon Sep 17 00:00:00 2001
From: Kunwar Grover
Date: Thu, 19 Sep 2024 14:03:51 +0530
Subject: [PATCH] [LLVMGPU] Explicitly set configs for vector distribution pipeline lowering tests (#18553)

This patch does two things:

- Before this patch, the tests in pipeline_vector_distribute.mlir ran the
  lowering-strategy selection pass together with the lowering pipeline.
  This prevented us from testing the lowering of different configs for the
  same kernel. This patch explicitly sets the configuration on the kernels
  in these tests (an illustrative sketch of the pattern is appended after
  the diff). Kernel-config selection logic for vector distribution is
  already covered by the tests in config_vector_distribute.mlir.
- The tests for gfx940 and gfx1100 used to live in the same file. Since we
  now set the kernel configuration explicitly, running a test that uses a
  gfx1100 intrinsic against the gfx940 test target would cause errors.
  This patch splits these tests into their own files.
---
 .../Codegen/LLVMGPU/test/ROCDL/BUILD.bazel    |   3 +-
 .../Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt |   3 +-
 .../pipeline_vector_distribute_gfx1100.mlir   |  96 ++++++
 ...=> pipeline_vector_distribute_gfx940.mlir} | 282 +++++-------------
 4 files changed, 174 insertions(+), 210 deletions(-)
 create mode 100644 compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx1100.mlir
 rename compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/{pipeline_vector_distribute.mlir => pipeline_vector_distribute_gfx940.mlir} (70%)

diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel
index 978d1cf3a896..4c25ee453397 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel
@@ -23,7 +23,8 @@ iree_lit_test_suite(
             "config_user_vector_distribute.mlir",
             "lowering_scalar_dispatch.mlir",
             "pipeline_tile_and_fuse.mlir",
-            "pipeline_vector_distribute.mlir",
+            "pipeline_vector_distribute_gfx940.mlir",
+            "pipeline_vector_distribute_gfx1100.mlir",
             "pipeline_warp_reduction.mlir",
         ],
         include = ["*.mlir"],
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt
index b4c1bf6e43b0..fb1d8edef4e7 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt
@@ -19,7 +19,8 @@ iree_lit_test_suite(
     "config_vector_distribute.mlir"
     "lowering_scalar_dispatch.mlir"
     "pipeline_tile_and_fuse.mlir"
-    "pipeline_vector_distribute.mlir"
+    "pipeline_vector_distribute_gfx1100.mlir"
+    "pipeline_vector_distribute_gfx940.mlir"
     "pipeline_warp_reduction.mlir"
   TOOLS
     FileCheck
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx1100.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx1100.mlir
new file mode 100644
index 000000000000..505c7b62cbd8
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx1100.mlir
@@ -0,0 +1,96 @@
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx1100 \
+// RUN:   --iree-codegen-llvmgpu-use-vector-distribution --iree-llvmgpu-enable-prefetch=true \
+// RUN:   --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-llvmgpu-lower-executable-target)))))" \
+// RUN:   %s | FileCheck %s
+
+#config = #iree_codegen.lowering_config
+#translation = #iree_codegen.translation_info, mma_schedule = #iree_gpu.mma_schedule, subgroup_m_count = 2, subgroup_n_count = 2>}>
+
+#pipeline_layout = #hal.pipeline.layout<bindings = [
+  #hal.pipeline.binding<storage_buffer, ReadOnly>,
+  #hal.pipeline.binding<storage_buffer, ReadOnly>,
+  #hal.pipeline.binding<storage_buffer>
+]>
+hal.executable @matmul_256x256x256_f16_f32 {
+hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
+  hal.executable.export @matmul_256x256x256_f16_f32 layout(#pipeline_layout) {
+    ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
+      %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
+      hal.return %x, %y, %z : index, index, index
+  }
+  builtin.module {
+    func.func @matmul_256x256x256_f16_f32() attributes {translation_info = #translation} {
+      %cst = arith.constant 0.000000e+00 : f32
+      %c0 = arith.constant 0 : index
+      %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xf16>>
+      %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xf16>>
+      %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<256x256xf32>>
+      %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf16>> -> tensor<256x256xf16>
+      %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf16>> -> tensor<256x256xf16>
+      %5 = tensor.empty() : tensor<256x256xf32>
+      %6 = linalg.fill {lowering_config = #config} ins(%cst : f32) outs(%5 : tensor<256x256xf32>) -> tensor<256x256xf32>
+      %7 = linalg.matmul {lowering_config = #config} ins(%3, %4 : tensor<256x256xf16>, tensor<256x256xf16>) outs(%6 : tensor<256x256xf32>) -> tensor<256x256xf32>
+      flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xf32> -> !flow.dispatch.tensor<writeonly:tensor<256x256xf32>>
+      return
+    }
+  }
+}
+}
+
+// CHECK-LABEL: func.func @matmul_256x256x256_f16_f32
+// CHECK: scf.for {{.*}} = %c0 to %c256 step %c128 iter_args({{.*}}) -> (vector<2x2x8x1x1x1xf32>)
+// Each subgroup handles 2 * 2 tiles, and for each tile we accumulate 8 times
+// along the K dimension. So in total 32 wmma ops.
+// CHECK-COUNT-32: amdgpu.wmma {{.*}} : vector<16xf16>, vector<16xf16>, vector<8xf32>
+// CHECK: scf.yield %{{.+}} : vector<2x2x8x1x1x1xf32>
+// Since each subgroup handles 2 * 2 tiles, and for each tile each lane holds 4 values,
+// we will have 32 writes. We cannot do contiguous writes since the output columns have
+// interleaved thread ids.
+// CHECK-COUNT-32: vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<1x1xf32>, memref<256x256xf32, #hal.descriptor_type<storage_buffer>>
+
+// -----
+
+#config = #iree_codegen.lowering_config
+#translation = #iree_codegen.translation_info, mma_schedule = #iree_gpu.mma_schedule, subgroup_m_count = 2, subgroup_n_count = 2>}>
+
+#pipeline_layout = #hal.pipeline.layout<bindings = [
+  #hal.pipeline.binding<storage_buffer, ReadOnly>,
+  #hal.pipeline.binding<storage_buffer, ReadOnly>,
+  #hal.pipeline.binding<storage_buffer>
+]>
+hal.executable @matmul_256x256x256_f16_f16 {
+hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
+  hal.executable.export @matmul_256x256x256_f16_f16 layout(#pipeline_layout) {
+    ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
+      %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
+      hal.return %x, %y, %z : index, index, index
+  }
+  builtin.module {
+    func.func @matmul_256x256x256_f16_f16() attributes {translation_info = #translation} {
+      %cst = arith.constant 0.000000e+00 : f16
+      %c0 = arith.constant 0 : index
+      %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xf16>>
+      %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<256x256xf16>>
+      %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<256x256xf16>>
+      %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf16>> -> tensor<256x256xf16>
+      %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<256x256xf16>> -> tensor<256x256xf16>
+      %5 = tensor.empty() : tensor<256x256xf16>
+      %6 = linalg.fill {lowering_config = #config} ins(%cst : f16) outs(%5 : tensor<256x256xf16>) -> tensor<256x256xf16>
+      %7 = linalg.matmul {lowering_config = #config} ins(%3, %4 : tensor<256x256xf16>, tensor<256x256xf16>) outs(%6 : tensor<256x256xf16>) -> tensor<256x256xf16>
+      flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xf16> -> !flow.dispatch.tensor<writeonly:tensor<256x256xf16>>
+      return
+    }
+  }
+}
+}
+
+// CHECK-LABEL: func.func @matmul_256x256x256_f16_f16
+// CHECK: scf.for {{.*}} = %c0 to %c256 step %c128 iter_args({{.*}}) -> (vector<2x2x16x1x1x1xf16>)
+// Each subgroup handles 2 * 2 tiles, and for each tile we accumulate 8 times
+// along the K dimension. So in total 32 wmma ops.
+// CHECK-COUNT-32: amdgpu.wmma {{.*}} : vector<16xf16>, vector<16xf16>, vector<16xf16>
+// CHECK: scf.yield %{{.+}} : vector<2x2x16x1x1x1xf16>
+// Since each subgroup handles 2 * 2 tiles, and for each tile each lane holds 4 values,
+// we will have 32 writes. We cannot do contiguous writes since the output columns have
+// interleaved thread ids.
+// CHECK-COUNT-32: vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<1x1xf16>, memref<256x256xf16, #hal.descriptor_type<storage_buffer>>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
similarity index 70%
rename from compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
rename to compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
index 8a6ffe5ee03d..86e7f0b15242 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
@@ -1,16 +1,10 @@
-// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx942 \
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx940 \
 // RUN:   --iree-codegen-llvmgpu-use-vector-distribution --iree-llvmgpu-enable-prefetch=true \
-// RUN:   --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-llvmgpu-select-lowering-strategy, func.func(iree-llvmgpu-lower-executable-target)))))" \
+// RUN:   --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-llvmgpu-lower-executable-target)))))" \
 // RUN:   %s | FileCheck %s
-
-// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx1100 \
-// RUN:   --iree-codegen-llvmgpu-use-vector-distribution --iree-llvmgpu-enable-prefetch=true \
-// RUN:   --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-llvmgpu-select-lowering-strategy, func.func(iree-llvmgpu-lower-executable-target)))))" \
-// RUN:   %s | FileCheck %s --check-prefix=RDNA3
-
-// TODO: This test is still using the legacy LLVMGPU kernel config. This needs
-// to be migrated to the rocdl heuristics, but for now is just physically
-// located here.
+#config = #iree_codegen.lowering_config +#translation = #iree_codegen.translation_info, mma_schedule = #iree_gpu.mma_schedule, subgroup_m_count = 2, subgroup_n_count = 2>}> #pipeline_layout = #hal.pipeline.layout, @@ -25,7 +19,7 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { hal.return %x, %y, %z : index, index, index } builtin.module { - func.func @matmul_256x256x256_f16_f32() { + func.func @matmul_256x256x256_f16_f32() attributes {translation_info = #translation} { %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> @@ -34,8 +28,8 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<256x256xf16> %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<256x256xf16> %5 = tensor.empty() : tensor<256x256xf32> - %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<256x256xf32>) -> tensor<256x256xf32> - %7 = linalg.matmul ins(%3, %4 : tensor<256x256xf16>, tensor<256x256xf16>) outs(%6 : tensor<256x256xf32>) -> tensor<256x256xf32> + %6 = linalg.fill {lowering_config = #config} ins(%cst : f32) outs(%5 : tensor<256x256xf32>) -> tensor<256x256xf32> + %7 = linalg.matmul {lowering_config = #config} ins(%3, %4 : tensor<256x256xf16>, tensor<256x256xf16>) outs(%6 : tensor<256x256xf32>) -> tensor<256x256xf32> flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xf32> -> !flow.dispatch.tensor> return } @@ -45,13 +39,7 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { // Basic pipeline test to make sure it generates the instructions we expect. -// CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info, -// CHECK-SAME: subgroup_m_count = 2, subgroup_n_count = 2> - // CHECK-LABEL: func.func @matmul_256x256x256_f16_f32() -// CHECK-SAME: translation_info = #[[$TRANSLATION]] // CHECK: scf.for {{.*}} = %c0 to %c256 step %c128 iter_args({{.*}}) -> (vector<2x2x1x1x4x1xf32>) // Each subgroup handles 2 * 2 tiles, and for each tile we accumulate 8 times // along the K dimension. So in total 32 mfma ops. 
@@ -61,6 +49,9 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { // ----- +#config = #iree_codegen.lowering_config +#translation = #iree_codegen.translation_info, mma_schedule = #iree_gpu.mma_schedule, subgroup_m_count = 2, subgroup_n_count = 2>}> + #pipeline_layout = #hal.pipeline.layout, #hal.pipeline.binding, @@ -74,7 +65,7 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { hal.return %x, %y, %z : index, index, index } builtin.module { - func.func @matmul_256x256x256_f16_f16() { + func.func @matmul_256x256x256_f16_f16() attributes {translation_info = #translation} { %cst = arith.constant 0.000000e+00 : f16 %c0 = arith.constant 0 : index %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> @@ -83,8 +74,8 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<256x256xf16> %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<256x256xf16> %5 = tensor.empty() : tensor<256x256xf16> - %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<256x256xf16>) -> tensor<256x256xf16> - %7 = linalg.matmul ins(%3, %4 : tensor<256x256xf16>, tensor<256x256xf16>) outs(%6 : tensor<256x256xf16>) -> tensor<256x256xf16> + %6 = linalg.fill {lowering_config = #config} ins(%cst : f16) outs(%5 : tensor<256x256xf16>) -> tensor<256x256xf16> + %7 = linalg.matmul {lowering_config = #config} ins(%3, %4 : tensor<256x256xf16>, tensor<256x256xf16>) outs(%6 : tensor<256x256xf16>) -> tensor<256x256xf16> flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xf16> -> !flow.dispatch.tensor> return } @@ -92,13 +83,7 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { } } -// CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info, -// CHECK-SAME: subgroup_m_count = 2, subgroup_n_count = 2> - // CHECK-LABEL: func.func @matmul_256x256x256_f16_f16() -// CHECK-SAME: translation_info = #[[$TRANSLATION]] // CHECK: scf.for {{.*}} = %c0 to %c256 step %c128 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<2x2x1x1x4x1xf16>) // CHECK: arith.extf %[[ARG]] {{.*}} : vector<2x2x1x1x4x1xf16> to vector<2x2x1x1x4x1xf32> // CHECK-COUNT-32: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> @@ -108,6 +93,9 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { // ----- +#config = #iree_codegen.lowering_config +#translation = #iree_codegen.translation_info, mma_schedule = #iree_gpu.mma_schedule, subgroup_m_count = 1, subgroup_n_count = 4>}> + #pipeline_layout = #hal.pipeline.layout, #hal.pipeline.binding, @@ -121,7 +109,7 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { hal.return %x, %y, %z : index, index, index } builtin.module { - func.func @expanded_matmul_transpose_b() { + func.func @expanded_matmul_transpose_b() attributes {translation_info = #translation} { %c0 = arith.constant 0 : index %cst = arith.constant 0.000000e+00 : f16 %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) @@ -136,14 +124,15 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { : !flow.dispatch.tensor> -> tensor<10x64x2048xf16> %5 = tensor.empty() : tensor<2x10x64x64xf16> - %6 = linalg.fill ins(%cst : f16) 
outs(%5 : tensor<2x10x64x64xf16>) -> tensor<2x10x64x64xf16> + %6 = linalg.fill {lowering_config = #config} ins(%cst : f16) outs(%5 : tensor<2x10x64x64xf16>) -> tensor<2x10x64x64xf16> %7 = linalg.generic { indexing_maps = [ affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)> ], - iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"] + iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"], + lowering_config = #config } ins(%3, %4 : tensor<2x64x2048xf16>, tensor<10x64x2048xf16>) outs(%6 : tensor<2x10x64x64xf16>) { ^bb0(%lhs: f16, %rhs: f16, %out: f16): %mul = arith.mulf %lhs, %rhs : f16 @@ -159,11 +148,7 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { } } -// CHECK: #[[TRANSLATION:.+]] = #iree_codegen.translation_info (vector<4x1x1x1x4x1xf16>) @@ -178,6 +163,9 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { // Basic f8, f8 -> f32 matmul. +#config = #iree_codegen.lowering_config +#translation = #iree_codegen.translation_info, mma_schedule = #iree_gpu.mma_schedule, subgroup_m_count = 2, subgroup_n_count = 2>}> + #pipeline_layout = #hal.pipeline.layout, #hal.pipeline.binding, @@ -191,7 +179,7 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { hal.return %x, %y, %z : index, index, index } builtin.module { - func.func @matmul_256x256x256_f8_f32() { + func.func @matmul_256x256x256_f8_f32() attributes {translation_info = #translation} { %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> @@ -200,8 +188,8 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<256x256xf8E4M3FNUZ> %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<256x256xf8E4M3FNUZ> %5 = tensor.empty() : tensor<256x256xf32> - %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<256x256xf32>) -> tensor<256x256xf32> - %7 = linalg.matmul ins(%3, %4 : tensor<256x256xf8E4M3FNUZ>, tensor<256x256xf8E4M3FNUZ>) outs(%6 : tensor<256x256xf32>) -> tensor<256x256xf32> + %6 = linalg.fill {lowering_config = #config} ins(%cst : f32) outs(%5 : tensor<256x256xf32>) -> tensor<256x256xf32> + %7 = linalg.matmul {lowering_config = #config} ins(%3, %4 : tensor<256x256xf8E4M3FNUZ>, tensor<256x256xf8E4M3FNUZ>) outs(%6 : tensor<256x256xf32>) -> tensor<256x256xf32> flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xf32> -> !flow.dispatch.tensor> return } @@ -211,13 +199,7 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { // Make sure it generates the mfma instructions we expect for f8 inputs. -// CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info, -// CHECK-SAME: subgroup_m_count = 2, subgroup_n_count = 2> - // CHECK-LABEL: func.func @matmul_256x256x256_f8_f32() -// CHECK-SAME: translation_info = #[[$TRANSLATION]] // Each subgroup handles 2 * 2 tiles, and for each tile we accumulate 8 times // along the K dimension. So in total 32 mfma ops. 
// CHECK-COUNT-32: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32> @@ -227,6 +209,9 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { // Basic i8, i8 -> i32 matmul. +#config = #iree_codegen.lowering_config +#translation = #iree_codegen.translation_info, mma_schedule = #iree_gpu.mma_schedule, subgroup_m_count = 2, subgroup_n_count = 2>}> + #pipeline_layout = #hal.pipeline.layout, #hal.pipeline.binding, @@ -240,7 +225,7 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { hal.return %x, %y, %z : index, index, index } builtin.module { - func.func @matmul_256x256x256_i8_i32() { + func.func @matmul_256x256x256_i8_i32() attributes {translation_info = #translation} { %cst = arith.constant 0 : i32 %c0 = arith.constant 0 : index %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> @@ -249,8 +234,8 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<256x256xi8> %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<256x256xi8> %5 = tensor.empty() : tensor<256x256xi32> - %6 = linalg.fill ins(%cst : i32) outs(%5 : tensor<256x256xi32>) -> tensor<256x256xi32> - %7 = linalg.matmul ins(%3, %4 : tensor<256x256xi8>, tensor<256x256xi8>) outs(%6 : tensor<256x256xi32>) -> tensor<256x256xi32> + %6 = linalg.fill {lowering_config = #config} ins(%cst : i32) outs(%5 : tensor<256x256xi32>) -> tensor<256x256xi32> + %7 = linalg.matmul {lowering_config = #config} ins(%3, %4 : tensor<256x256xi8>, tensor<256x256xi8>) outs(%6 : tensor<256x256xi32>) -> tensor<256x256xi32> flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xi32> -> !flow.dispatch.tensor> return } @@ -260,13 +245,7 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { // Make sure it generates the mfma instructions we expect for integer inputs. -// CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info, -// CHECK-SAME: subgroup_m_count = 2, subgroup_n_count = 2> - // CHECK-LABEL: func.func @matmul_256x256x256_i8_i32() -// CHECK-SAME: translation_info = #[[$TRANSLATION]] // Each subgroup handles 2 * 2 tiles, and for each tile we accumulate 8 times // along the K dimension. So in total 32 mfma ops. // CHECK-COUNT-32: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<8xi8>, vector<8xi8>, vector<4xi32> @@ -276,6 +255,9 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { // Basic i8, i8 -> i32 matmul_transpose_b. 
+#config = #iree_codegen.lowering_config +#translation = #iree_codegen.translation_info, mma_schedule = #iree_gpu.mma_schedule, subgroup_m_count = 2, subgroup_n_count = 2>}> + #pipeline_layout = #hal.pipeline.layout, #hal.pipeline.binding, @@ -289,7 +271,7 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { hal.return %x, %y, %z : index, index, index } builtin.module { - func.func @matmul_transpose_b_256x256x256_i8_i32() { + func.func @matmul_transpose_b_256x256x256_i8_i32() attributes {translation_info = #translation} { %cst = arith.constant 0 : i32 %c0 = arith.constant 0 : index %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> @@ -298,8 +280,8 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<256x256xi8> %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<256x256xi8> %5 = tensor.empty() : tensor<256x256xi32> - %6 = linalg.fill ins(%cst : i32) outs(%5 : tensor<256x256xi32>) -> tensor<256x256xi32> - %7 = linalg.matmul_transpose_b ins(%3, %4 : tensor<256x256xi8>, tensor<256x256xi8>) outs(%6 : tensor<256x256xi32>) -> tensor<256x256xi32> + %6 = linalg.fill {lowering_config = #config} ins(%cst : i32) outs(%5 : tensor<256x256xi32>) -> tensor<256x256xi32> + %7 = linalg.matmul_transpose_b {lowering_config = #config} ins(%3, %4 : tensor<256x256xi8>, tensor<256x256xi8>) outs(%6 : tensor<256x256xi32>) -> tensor<256x256xi32> flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xi32> -> !flow.dispatch.tensor> return } @@ -309,13 +291,7 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { // Make sure it generates the mfma instructions we expect for integer inputs. -// CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info, -// CHECK-SAME: subgroup_m_count = 2, subgroup_n_count = 2> - // CHECK-LABEL: func.func @matmul_transpose_b_256x256x256_i8_i32() -// CHECK-SAME: translation_info = #[[$TRANSLATION]] // Each subgroup handles 2 * 2 tiles, and for each tile we accumulate 8 times // along the K dimension. So in total 32 mfma ops. 
// CHECK-COUNT-32: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<8xi8>, vector<8xi8>, vector<4xi32> @@ -323,6 +299,9 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { // ----- +#config = #iree_codegen.lowering_config +#translation = #iree_codegen.translation_info, mma_schedule = #iree_gpu.mma_schedule, subgroup_m_count = 2, subgroup_n_count = 2>}> + #pipeline_layout = #hal.pipeline.layout, #hal.pipeline.binding, @@ -336,7 +315,7 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { hal.return %x, %y, %z : index, index, index } builtin.module { - func.func @conv_nhwc() { + func.func @conv_nhwc() attributes {translation_info = #translation} { %c0 = arith.constant 0 : index %cst = arith.constant 0.000000e+00 : f32 %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) : !flow.dispatch.tensor> @@ -345,8 +324,8 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [2, 258, 514, 768], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<2x258x514x768xf16> %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [3, 3, 768, 256], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<3x3x768x256xf16> %5 = tensor.empty() : tensor<2x256x512x256xf32> - %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<2x256x512x256xf32>) -> tensor<2x256x512x256xf32> - %7 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%3, %4 : tensor<2x258x514x768xf16>, tensor<3x3x768x256xf16>) outs(%6 : tensor<2x256x512x256xf32>) -> tensor<2x256x512x256xf32> + %6 = linalg.fill {lowering_config = #config} ins(%cst : f32) outs(%5 : tensor<2x256x512x256xf32>) -> tensor<2x256x512x256xf32> + %7 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>, lowering_config = #config} ins(%3, %4 : tensor<2x258x514x768xf16>, tensor<3x3x768x256xf16>) outs(%6 : tensor<2x256x512x256xf32>) -> tensor<2x256x512x256xf32> flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0, 0], sizes = [2, 256, 512, 256], strides = [1, 1, 1, 1] : tensor<2x256x512x256xf32> -> !flow.dispatch.tensor> return } @@ -363,6 +342,9 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { // ----- +#config = #iree_codegen.lowering_config +#translation = #iree_codegen.translation_info, mma_schedule = #iree_gpu.mma_schedule, subgroup_m_count = 2, subgroup_n_count = 2>}> + #pipeline_layout = #hal.pipeline.layout, #hal.pipeline.binding, @@ -380,7 +362,7 @@ hal.executable public @main_dispatch_expanded_matmul { hal.return %x, %y, %z : index, index, index } builtin.module { - func.func @generic_2x1024x20x64x1280_f16() { + func.func @generic_2x1024x20x64x1280_f16() attributes {translation_info = #translation} { %cst = arith.constant 0.000000e+00 : f16 %c0 = arith.constant 0 : index %0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : i32 @@ -393,10 +375,11 @@ hal.executable public @main_dispatch_expanded_matmul { %7 = flow.dispatch.tensor.load %4, offsets = [0, 0, 0], sizes = [2, 1024, 1280], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<2x1024x1280xf16> %8 = flow.dispatch.tensor.load %5, offsets = [0, 0, 0], sizes = [20, 64, 1280], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<20x64x1280xf16> %9 = tensor.empty() : tensor<2x1024x20x64xf16> - %10 = linalg.fill ins(%cst : f16) outs(%9 : tensor<2x1024x20x64xf16>) -> 
tensor<2x1024x20x64xf16>
+        %10 = linalg.fill {lowering_config = #config} ins(%cst : f16) outs(%9 : tensor<2x1024x20x64xf16>) -> tensor<2x1024x20x64xf16>
         %11 = linalg.generic {
           indexing_maps = [#map, #map1, #map2],
-          iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]
+          iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"],
+          lowering_config = #config
         } ins(%7, %8 : tensor<2x1024x1280xf16>, tensor<20x64x1280xf16>)
         outs(%10 : tensor<2x1024x20x64xf16>) {
         ^bb0(%in: f16, %in_0: f16, %out: f16):
@@ -412,11 +395,6 @@ hal.executable public @main_dispatch_expanded_matmul {
 }
 
-// CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info,
-// CHECK-SAME: subgroup_m_count = 2, subgroup_n_count = 2>
-
 // CHECK-LABEL: func.func @generic_2x1024x20x64x1280_f16
 // This has more than 2 iterations. So we have prefetching enabled for this case. Due to
 // prefetching, we have one iteration peeled off, so the upper bound is 1280 - 128 = 1152.
@@ -430,105 +408,8 @@
 
-#pipeline_layout = #hal.pipeline.layout,
-  #hal.pipeline.binding,
-  #hal.pipeline.binding
-]>
-hal.executable @matmul_256x256x256_f16_f32 {
-hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
-  hal.executable.export @matmul_256x256x256_f16_f32 layout(#pipeline_layout) {
-    ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
-      %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
-      hal.return %x, %y, %z : index, index, index
-  }
-  builtin.module {
-    func.func @matmul_256x256x256_f16_f32() {
-      %cst = arith.constant 0.000000e+00 : f32
-      %c0 = arith.constant 0 : index
-      %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor>
-      %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor>
-      %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor>
-      %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<256x256xf16>
-      %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<256x256xf16>
-      %5 = tensor.empty() : tensor<256x256xf32>
-      %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<256x256xf32>) -> tensor<256x256xf32>
-      %7 = linalg.matmul ins(%3, %4 : tensor<256x256xf16>, tensor<256x256xf16>) outs(%6 : tensor<256x256xf32>) -> tensor<256x256xf32>
-      flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xf32> -> !flow.dispatch.tensor>
-      return
-    }
-  }
-}
-}
-
-// RDNA3: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info,
-// RDNA3-SAME: subgroup_m_count = 2, subgroup_n_count = 2>
-
-// RDNA3-LABEL: func.func @matmul_256x256x256_f16_f32
-// RDNA3-SAME: translation_info = #[[$TRANSLATION]]
-// RDNA3: scf.for {{.*}} = %c0 to %c256 step %c128 iter_args({{.*}}) -> (vector<2x2x8x1x1x1xf32>)
-// Each subgroup handles 2 * 2 tiles, and for each tile we accumulate 8 times
-// along the K dimension. So in total 32 wmma ops.
-// RDNA3-COUNT-32: amdgpu.wmma {{.*}} : vector<16xf16>, vector<16xf16>, vector<8xf32>
-// RDNA3: scf.yield %{{.+}} : vector<2x2x8x1x1x1xf32>
-// Since each subgroup handles 2 * 2 tiles, and for each tile, each lane holds 4 values.
-// we will have 32 writes.
We cannot do contiguous writes since the outputs columns has interleaved -// thread ids. -// RDNA3-COUNT-32: vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<1x1xf32>, memref<256x256xf32, #hal.descriptor_type> - -// ----- - -#pipeline_layout = #hal.pipeline.layout, - #hal.pipeline.binding, - #hal.pipeline.binding -]> -hal.executable @matmul_256x256x256_f16_f16 { -hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { - hal.executable.export @matmul_256x256x256_f16_f16 layout(#pipeline_layout) { - ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index): - %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2 - hal.return %x, %y, %z : index, index, index - } - builtin.module { - func.func @matmul_256x256x256_f16_f16() { - %cst = arith.constant 0.000000e+00 : f16 - %c0 = arith.constant 0 : index - %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> - %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> - %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor> - %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<256x256xf16> - %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !flow.dispatch.tensor> -> tensor<256x256xf16> - %5 = tensor.empty() : tensor<256x256xf16> - %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<256x256xf16>) -> tensor<256x256xf16> - %7 = linalg.matmul ins(%3, %4 : tensor<256x256xf16>, tensor<256x256xf16>) outs(%6 : tensor<256x256xf16>) -> tensor<256x256xf16> - flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xf16> -> !flow.dispatch.tensor> - return - } - } -} -} - -// RDNA3: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info, -// RDNA3-SAME: subgroup_m_count = 2, subgroup_n_count = 2> - -// RDNA3-LABEL: func.func @matmul_256x256x256_f16_f16 -// RDNA3-SAME: translation_info = #[[$TRANSLATION]] -// RDNA3: scf.for {{.*}} = %c0 to %c256 step %c128 iter_args({{.*}}) -> (vector<2x2x16x1x1x1xf16>) -// Each subgroup handles 2 * 2 tiles, and for each tile we accumulate 8 times -// along the K dimension. So in total 32 wmma ops. -// RDNA3-COUNT-32: amdgpu.wmma {{.*}} : vector<16xf16>, vector<16xf16>, vector<16xf16> -// RDNA3: scf.yield %{{.+}} : vector<2x2x16x1x1x1xf16> -// Since each subgroup handles 2 * 2 tiles, and for each tile, each lane holds 4 values. -// we will have 32 writes. We cannot do contiguous writes since the outputs columns has interleaved -// thread ids. 
-// RDNA3-COUNT-32: vector.transfer_write {{.+}} {in_bounds = [true, true]} : vector<1x1xf16>, memref<256x256xf16, #hal.descriptor_type> - -// ----- +#config = #iree_codegen.lowering_config +#translation = #iree_codegen.translation_info, mma_schedule = #iree_gpu.mma_schedule, subgroup_m_count = 1, subgroup_n_count = 1>}> #pipeline_layout = #hal.pipeline.layout, @@ -543,7 +424,7 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { hal.return %x, %y, %z : index, index, index } builtin.module { - func.func @unaligned_nk_batch_matmul() { + func.func @unaligned_nk_batch_matmul() attributes {translation_info = #translation} { %cst = arith.constant 0.000000e+00 : f16 %c0 = arith.constant 0 : index %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> @@ -552,8 +433,8 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [64, 968, 1281], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<64x968x1281xf16> %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [64, 1281, 1281], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<64x1281x1281xf16> %5 = tensor.empty() : tensor<64x968x1281xf16> - %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<64x968x1281xf16>) -> tensor<64x968x1281xf16> - %7 = linalg.batch_matmul ins(%3, %4 : tensor<64x968x1281xf16>, tensor<64x1281x1281xf16>) outs(%6 : tensor<64x968x1281xf16>) -> tensor<64x968x1281xf16> + %6 = linalg.fill {lowering_config = #config} ins(%cst : f16) outs(%5 : tensor<64x968x1281xf16>) -> tensor<64x968x1281xf16> + %7 = linalg.batch_matmul {lowering_config = #config} ins(%3, %4 : tensor<64x968x1281xf16>, tensor<64x1281x1281xf16>) outs(%6 : tensor<64x968x1281xf16>) -> tensor<64x968x1281xf16> flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0], sizes = [64, 968, 1281], strides = [1, 1, 1] : tensor<64x968x1281xf16> -> !flow.dispatch.tensor> return } @@ -562,13 +443,7 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { } // Basic pipeline test to make sure it generates the instructions we expect. -// CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info, -// CHECK-SAME: subgroup_m_count = 1, subgroup_n_count = 1> - // CHECK-LABEL: func.func @unaligned_nk_batch_matmul() -// CHECK-SAME: translation_info = #[[$TRANSLATION]] // CHECK-DAG: %[[RHS_SHARED:.+]] = memref.alloc() : memref<1x16x20xf16, #gpu.address_space> // CHECK-DAG: %[[RHS_SHARED_SUB:.+]] = memref.subview %[[RHS_SHARED]][0, 0, 0] [1, 16, 16] [1, 1, 1] // CHECK-DAG: %[[LHS_SHARED:.+]] = memref.alloc() : memref<1x16x20xf16, #gpu.address_space> @@ -612,6 +487,9 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) { // NOTE: This test is not exhaustive of all possible ways the above condition is breaking, // but rather is an example of a matmul shape from a model that broke our compilation heuristic. 
+#config = #iree_codegen.lowering_config +#translation = #iree_codegen.translation_info, mma_schedule = #iree_gpu.mma_schedule, subgroup_m_count = 1, subgroup_n_count = 4>}> + #pipeline_layout = #hal.pipeline.layout, #hal.pipeline.binding @@ -624,7 +502,7 @@ hal.executable public @contract_schedule_considering_read_layout { hal.return %x, %y, %z : index, index, index } builtin.module { - func.func @contract_schedule_considering_read_layout() { + func.func @contract_schedule_considering_read_layout() attributes {translation_info = #translation} { %cst = arith.constant 0.000000e+00 : f16 %0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : i32 %1 = hal.interface.constant.load layout(#pipeline_layout) ordinal(1) : i32 @@ -638,8 +516,8 @@ hal.executable public @contract_schedule_considering_read_layout { %9 = flow.dispatch.tensor.load %6, offsets = [0, 0, 0], sizes = [2, 160, 1536], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<2x160x1536xf16> %10 = flow.dispatch.tensor.load %7, offsets = [0, 0, 0], sizes = [2, 1536, 1536], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<2x1536x1536xf16> %11 = tensor.empty() : tensor<2x160x1536xf16> - %12 = linalg.fill ins(%cst : f16) outs(%11 : tensor<2x160x1536xf16>) -> tensor<2x160x1536xf16> - %13 = linalg.batch_matmul ins(%9, %10 : tensor<2x160x1536xf16>, tensor<2x1536x1536xf16>) outs(%12 : tensor<2x160x1536xf16>) -> tensor<2x160x1536xf16> + %12 = linalg.fill {lowering_config = #config} ins(%cst : f16) outs(%11 : tensor<2x160x1536xf16>) -> tensor<2x160x1536xf16> + %13 = linalg.batch_matmul {lowering_config = #config} ins(%9, %10 : tensor<2x160x1536xf16>, tensor<2x1536x1536xf16>) outs(%12 : tensor<2x160x1536xf16>) -> tensor<2x160x1536xf16> flow.dispatch.tensor.store %13, %8, offsets = [0, 0, 0], sizes = [2, 160, 1536], strides = [1, 1, 1] : tensor<2x160x1536xf16> -> !flow.dispatch.tensor> return } @@ -648,13 +526,7 @@ hal.executable public @contract_schedule_considering_read_layout { } // Basic pipeline test to make sure it generates the instructions we expect. 
-// CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info, -// CHECK-SAME: subgroup_m_count = 1, subgroup_n_count = 4> - // CHECK-LABEL: func.func @contract_schedule_considering_read_layout() -// CHECK-SAME: translation_info = #[[$TRANSLATION]] // CHECK-DAG: %[[RHS_SHARED:.+]] = memref.alloc() : memref<128x132xf16, #gpu.address_space> // CHECK-DAG: %[[RHS_SHARED_SUB:.+]] = memref.subview %[[RHS_SHARED]][0, 0] [128, 128] [1, 1] // CHECK-DAG: %[[LHS_SHARED:.+]] = memref.alloc() : memref<16x132xf16, #gpu.address_space> @@ -666,6 +538,9 @@ hal.executable public @contract_schedule_considering_read_layout { // ----- +#config = #iree_codegen.lowering_config +#translation = #iree_codegen.translation_info, subgroup_m_count = 2, subgroup_n_count = 1>}> + #pipeline_layout = #hal.pipeline.layout, #hal.pipeline.binding, @@ -680,7 +555,7 @@ hal.executable private @attention_20x4096x64x4096x64 { hal.return %x, %y, %z : index, index, index } builtin.module { - func.func @attention_20x4096x64x4096x64() { + func.func @attention_20x4096x64x4096x64() attributes {translation_info = #translation} { %cst = arith.constant 1.250000e-01 : f16 %c0 = arith.constant 0 : index %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor> @@ -694,7 +569,8 @@ hal.executable private @attention_20x4096x64x4096x64 { %8 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, - affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>], + lowering_config = #config} ins(%4, %5, %6, %cst : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, f16) outs(%7 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16> flow.dispatch.tensor.store %8, %3, offsets = [0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : tensor<20x4096x64xf16> -> !flow.dispatch.tensor> return @@ -705,14 +581,7 @@ hal.executable private @attention_20x4096x64x4096x64 { // Basic test to make sure we can handle attention -// CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info, -// CHECK-SAME: subgroup_m_count = 2, subgroup_n_count = 1> -// Prefetching is disabled for attention for now -// CHECK-NOT: gpu_pipeline_options = #iree_gpu.pipeline_options +#translation = #iree_codegen.translation_info, subgroup_m_count = 2, subgroup_n_count = 1>}> + #pipeline_layout = #hal.pipeline.layout, #hal.pipeline.binding, @@ -742,7 +614,7 @@ hal.executable private @attention_multiple_m_transpose { hal.return %x, %y, %z : index, index, index } builtin.module { - func.func @attention_multiple_m_transpose() { + func.func @attention_multiple_m_transpose() attributes {translation_info = #translation} { %cst = arith.constant 1.0 : f16 %c0 = arith.constant 0 : index %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) : !flow.dispatch.tensor> @@ -754,8 +626,8 @@ hal.executable private @attention_multiple_m_transpose { %6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [24, 4608, 128], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<24x4608x128xf16> %7 = tensor.empty() : tensor<64x4608x24x128xf16> %8 = tensor.empty() : tensor<24x64x4608x128xf16> - %9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, 
d5) -> (d0, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>]} ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) -> tensor<24x64x4608x128xf16> - %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2, d0, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%9 : tensor<24x64x4608x128xf16>) outs(%7 : tensor<64x4608x24x128xf16>) { + %9 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>], lowering_config = #config} ins(%4, %5, %6, %cst : tensor<24x64x4608x128xf16>, tensor<24x4608x128xf16>, tensor<24x4608x128xf16>, f16) outs(%8 : tensor<24x64x4608x128xf16>) -> tensor<24x64x4608x128xf16> + %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2, d0, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], lowering_config = #config} ins(%9 : tensor<24x64x4608x128xf16>) outs(%7 : tensor<64x4608x24x128xf16>) { ^bb0(%in: f16, %out: f16): linalg.yield %in : f16 } -> tensor<64x4608x24x128xf16> @@ -766,13 +638,7 @@ hal.executable private @attention_multiple_m_transpose { } } -// CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info, -// CHECK-SAME: subgroup_m_count = 2, subgroup_n_count = 1> - // CHECK-LABEL: func.func @attention_multiple_m_transpose() -// CHECK-SAME: translation_info = #[[$TRANSLATION]] - // CHECK: scf.for %{{.*}} = %c0 to %c72 step %c1 // CHECK-SAME: -> (vector<2x1x4xf32>, vector<2x1x4xf32>, vector<2x8x1x1x4x1xf32>) // CHECK-COUNT-96: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
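
An illustrative sketch of the explicitly pinned configuration pattern this patch
applies (not part of the diff above; the tile sizes, workgroup size, and subgroup
size below are invented for illustration, and the real tests additionally pin an
mma_schedule on the translation info):

// Hypothetical standalone example: the lowering config and translation info are
// attached directly as attributes, so no strategy-selection pass is needed to
// infer them before running iree-llvmgpu-lower-executable-target.
#example_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64, 64]]>
#example_translation = #iree_codegen.translation_info<LLVMGPUVectorDistribute
                         workgroup_size = [128, 2, 1] subgroup_size = 64>
func.func @example_matmul(%lhs: tensor<256x256xf16>, %rhs: tensor<256x256xf16>,
                          %acc: tensor<256x256xf32>) -> tensor<256x256xf32>
    attributes {translation_info = #example_translation} {
  // The config rides on the op itself (hypothetical values).
  %0 = linalg.matmul {lowering_config = #example_config}
         ins(%lhs, %rhs : tensor<256x256xf16>, tensor<256x256xf16>)
         outs(%acc : tensor<256x256xf32>) -> tensor<256x256xf32>
  return %0 : tensor<256x256xf32>
}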