diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp index 4b3ae90362b9..b953d78f9b4d 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp @@ -750,7 +750,7 @@ struct DistributeBroadcastLayoutAttr final /// sequence of multiplications and additions. /// struct DistributeLayoutConflictResolutions final - : OpDistributionPattern { + : OpDistributionPattern { using OpDistributionPattern::OpDistributionPattern; VectorValue reshapeVector(Location loc, RewriterBase &rewriter, @@ -792,10 +792,9 @@ struct DistributeLayoutConflictResolutions final return newVector; } - LogicalResult - matchAndRewrite(IREE::VectorExt::LayoutConflictResolutionOp resolutionOp, - DistributionSignature &signature, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(IREE::VectorExt::ToLayoutOp resolutionOp, + DistributionSignature &signature, + PatternRewriter &rewriter) const override { VectorValue vector = resolutionOp.getInput(); VectorValue result = resolutionOp.getOutput(); LayoutAttr currentLayout = dyn_cast(signature[vector]); @@ -837,13 +836,12 @@ struct DistributeLayoutConflictResolutions final /// especially used when we don't have an optimized way /// to resolve the conflict. struct DistributeLayoutConflictToSharedMemory final - : OpDistributionPattern { + : OpDistributionPattern { using OpDistributionPattern::OpDistributionPattern; - LogicalResult - matchAndRewrite(IREE::VectorExt::LayoutConflictResolutionOp resolutionOp, - DistributionSignature &signature, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(IREE::VectorExt::ToLayoutOp resolutionOp, + DistributionSignature &signature, + PatternRewriter &rewriter) const override { auto loc = resolutionOp.getLoc(); VectorValue vector = resolutionOp.getInput(); VectorValue result = resolutionOp.getOutput(); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir index e5ecf598359b..2654eaede125 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_vector_distribution.mlir @@ -643,7 +643,7 @@ func.func @unresolved_layout_conflict(%a : memref<32x16xf16>, %b : memref<32x16x %vcst = arith.constant dense<0.0> : vector<32x16xf16> // CHECK-COUNT-8: vector.load %[[MEM]] %vec = vector.transfer_read %a[%c0, %c0], %cst {"__vector_layout_test_anchor_result_0" = #layout1} : memref<32x16xf16>, vector<32x16xf16> - // CHECK: iree_vector_ext.layout_conflict_resolution {{.*}} + // CHECK: iree_vector_ext.to_layout {{.*}} %vec2 = arith.addf %vec, %vcst : vector<32x16xf16> // CHECK-COUNT-16: vector.store {{.*}}, vector<1xf16> vector.transfer_write %vec2, %b[%c0, %c0] {in_bounds = [true, true], diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp index a8aae488f14a..69c7ff978e19 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp @@ -1152,7 +1152,7 @@ static void emitLayoutRemarks(VectorLayoutAnalysis &analysis, mlir::FunctionOpInterface funcOp) { funcOp.walk([&](Operation *op) { // Do not emit remarks for conflict operations. - if (isa(op)) { + if (isa(op)) { return; } diff --git a/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp b/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp index 394d7403e79a..98ad063a0786 100644 --- a/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp @@ -217,9 +217,8 @@ ChangeResult DistributionLayout::resolveWithPossibleConflict( Value input = opOperand.get(); // Create a resolution operation. This conflict should be handeled later by // someone else, not this analysis. - Operation *resolveOp = - builder.create( - input.getLoc(), input.getType(), input, vectorLayout, rhs); + Operation *resolveOp = builder.create( + input.getLoc(), input.getType(), input, rhs); Value resolvedValue = resolveOp->getResult(0); opOperand.set(resolvedValue); @@ -1015,9 +1014,9 @@ void VectorLayoutAnalysis::debugAnnotateLayouts() { continue; } - // Do not annotate resolve_conflict operations since they already have + // Do not annotate to_layout operations since they already have // this information in their attributes. - if (isa(op)) { + if (isa(op)) { continue; } diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp index 6c9f90a34c7a..77d0e95b7475 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp @@ -1676,10 +1676,8 @@ transform_dialect::SetContractionLayoutAttributes::apply( if (!parentOp || (parentOp->getNumResults() != 1)) continue; parentOp->setAttr("__vector_layout_test_anchor_result_0", readLayout); - Value resolvedOperand = - rewriter.create( - contract.getLoc(), operand.getType(), operand, layout, - readLayout); + Value resolvedOperand = rewriter.create( + contract.getLoc(), operand.getType(), operand, layout); contract.setOperand(operandIndices[i], resolvedOperand); } } diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.td index fdaecc6872c6..f21607a25a2c 100644 --- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.td +++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtOps.td @@ -24,23 +24,24 @@ class IREEVectorExt_PureOp traits = []> : // Layout ops. //===----------------------------------------------------------------------===// -def IREEVectorExt_LayoutConflictResolutionOp : IREEVectorExt_PureOp<"layout_conflict_resolution"> { - let summary = "Layout Conflict Resolution operator"; +def IREEVectorExt_ToLayoutOp : IREEVectorExt_PureOp<"to_layout", [ + Pure, + AllTypesMatch<["input", "output"]> + ]> { + let summary = "Layout conversion operator"; let description = [{ - The layout conflict resolution operator takes a vector and a - desired layout and transforms the vector to one with the - desired layout. + The layout conversion operator takes a shaped value and a layout and + transforms the value to have that layout. }]; let arguments = (ins AnyVector:$input, - VectorLayoutInterface:$sourceLayout, - VectorLayoutInterface:$desiredLayout + VectorLayoutInterface:$layout ); let results = (outs AnyVector:$output ); let extraClassDeclaration = [{}]; - let assemblyFormat = "$input attr-dict `:` type($input) `->` type($output)"; + let assemblyFormat = "$input `to` $layout attr-dict `:` type($input)"; let hasVerifier = 1; } diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp index dd52c2ce41fa..c6943989c515 100644 --- a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp +++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtOps.cpp @@ -17,15 +17,9 @@ using VectorValue = TypedValue; // LayoutConflictResolutionOp //===----------------------------------------------------------------------===// -// Validate that the desired layout has the same shape as the input. -LogicalResult LayoutConflictResolutionOp::verify() { - if (getSourceLayout().isValidLayout(getInput()).failed()) { - return failure(); - } - if (getDesiredLayout().isValidLayout(getOutput()).failed()) { - return failure(); - } - return success(); +// Validate that the layout has the same shape as the input. +LogicalResult ToLayoutOp::verify() { + return getLayout().isValidLayout(getInput()); } // to_simd -> to_simt diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/invalid.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/invalid.mlir index 739f3d434b93..7a848362524f 100644 --- a/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/invalid.mlir +++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/invalid.mlir @@ -3,28 +3,12 @@ #row_layout1 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX, VECTORY], [1, 1, 1]> #col_layout1 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [4, 2, 4]> #layout1 = #iree_vector_ext.layout<#row_layout1, #col_layout1> -#layout2 = #iree_vector_ext.layout<#col_layout1, #col_layout1> -func.func @invalid_desired_layout(%lhs: memref<32x32xf16>, %rhs: memref<32x32xf16>) -> vector<32x32xf16> { +func.func @invalid_layout(%lhs: memref<32x32xf16>, %rhs: memref<32x32xf16>) -> vector<32x32xf16> { %cst_0 = arith.constant 0.0 : f16 %c0 = arith.constant 0 : index - %result = vector.transfer_read %lhs[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<32x32xf16>, vector<32x32xf16> // expected-error @+1 {{Vector shape: [32, 32] does not match the layout (layout<<[ BATCHX, LANEX, VECTORY], [1, 1, 1]>, <[ BATCHY, LANEY, VECTORX], [4, 2, 4]>>) at dim 0. Dimension expected by layout: 1 actual: 32}} - %2 = iree_vector_ext.layout_conflict_resolution %result {desiredLayout = #layout1, sourceLayout = #layout2} : vector<32x32xf16> -> vector<32x32xf16> - return %2 : vector<32x32xf16> -} - -// ----- - -#row_layout1 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX, VECTORY], [1, 1, 1]> -#col_layout1 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [4, 2, 4]> -#layout1 = #iree_vector_ext.layout<#row_layout1, #col_layout1> -#layout2 = #iree_vector_ext.layout<#col_layout1, #col_layout1> -func.func @invalid_source_layout(%lhs: memref<32x32xf16>, %rhs: memref<32x32xf16>) -> vector<32x32xf16> { - %cst_0 = arith.constant 0.0 : f16 - %c0 = arith.constant 0 : index %result = vector.transfer_read %lhs[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<32x32xf16>, vector<32x32xf16> - // expected-error @-1 {{Vector shape: [32, 32] does not match the layout (layout<<[ BATCHX, LANEX, VECTORY], [1, 1, 1]>, <[ BATCHY, LANEY, VECTORX], [4, 2, 4]>>) at dim 0. Dimension expected by layout: 1 actual: 32}} - %2 = iree_vector_ext.layout_conflict_resolution %result {desiredLayout = #layout2, sourceLayout = #layout1} : vector<32x32xf16> -> vector<32x32xf16> + %2 = iree_vector_ext.to_layout %result to #layout1 : vector<32x32xf16> return %2 : vector<32x32xf16> } diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir index 49fe27c2f86f..5d8018df03f6 100644 --- a/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir +++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/roundtrip.mlir @@ -2,22 +2,18 @@ #row_layout1 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX, VECTORY], [2, 4, 4]> #col_layout1 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [4, 2, 4]> -#layout1 = #iree_vector_ext.layout<#row_layout1, #col_layout1> #layout2 = #iree_vector_ext.layout<#col_layout1, #row_layout1> func.func @specify_layout(%lhs: memref<32x32xf16>) -> vector<32x32xf16> { %cst_0 = arith.constant 0.0 : f16 %c0 = arith.constant 0 : index %result = vector.transfer_read %lhs[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<32x32xf16>, vector<32x32xf16> - %2 = iree_vector_ext.layout_conflict_resolution %result {sourceLayout = #layout1, desiredLayout = #layout2} : vector<32x32xf16> -> vector<32x32xf16> + %2 = iree_vector_ext.to_layout %result to #layout2 : vector<32x32xf16> return %2 : vector<32x32xf16> } // CHECK-DAG: #[[LAYOUT0:.+]] = #iree_vector_ext.layout<<[ BATCHY, LANEY, VECTORX], [4, 2, 4]>, <[ BATCHX, LANEX, VECTORY], [2, 4, 4]>> -// CHECK-DAG: #[[LAYOUT1:.+]] = #iree_vector_ext.layout<<[ BATCHX, LANEX, VECTORY], [2, 4, 4]>, <[ BATCHY, LANEY, VECTORX], [4, 2, 4]>> // CHECK-LABEL: func.func @specify_layout -// CHECK: iree_vector_ext.layout_conflict_resolution -// CHECK-SAME: desiredLayout = #[[LAYOUT0]] -// CHECK-SAME: sourceLayout = #[[LAYOUT1]] +// CHECK: iree_vector_ext.to_layout {{.*}} to #[[LAYOUT0]] // -----