Skip to content

Commit

Permalink
[VectorDistribution] Replace layout_resolution with to_layout (#18027)
Browse files Browse the repository at this point in the history
This patch replaces the layout_resolution operator with a new
"to_layout" operation, representing a layout cast on the result. This
allows the operation to be used as an anchor and a conversion operation.

This operation will be used in later patches to set layout anchors in IR
and preserve them.
  • Loading branch information
Groverkss committed Jul 30, 2024
1 parent 998ed49 commit 18c183f
Show file tree
Hide file tree
Showing 9 changed files with 32 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -750,7 +750,7 @@ struct DistributeBroadcastLayoutAttr final
/// sequence of multiplications and additions.
///
struct DistributeLayoutConflictResolutions final
: OpDistributionPattern<IREE::VectorExt::LayoutConflictResolutionOp> {
: OpDistributionPattern<IREE::VectorExt::ToLayoutOp> {
using OpDistributionPattern::OpDistributionPattern;

VectorValue reshapeVector(Location loc, RewriterBase &rewriter,
Expand Down Expand Up @@ -792,10 +792,9 @@ struct DistributeLayoutConflictResolutions final
return newVector;
}

LogicalResult
matchAndRewrite(IREE::VectorExt::LayoutConflictResolutionOp resolutionOp,
DistributionSignature &signature,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(IREE::VectorExt::ToLayoutOp resolutionOp,
DistributionSignature &signature,
PatternRewriter &rewriter) const override {
VectorValue vector = resolutionOp.getInput();
VectorValue result = resolutionOp.getOutput();
LayoutAttr currentLayout = dyn_cast<LayoutAttr>(signature[vector]);
Expand Down Expand Up @@ -837,13 +836,12 @@ struct DistributeLayoutConflictResolutions final
/// especially used when we don't have an optimized way
/// to resolve the conflict.
struct DistributeLayoutConflictToSharedMemory final
: OpDistributionPattern<IREE::VectorExt::LayoutConflictResolutionOp> {
: OpDistributionPattern<IREE::VectorExt::ToLayoutOp> {
using OpDistributionPattern::OpDistributionPattern;

LogicalResult
matchAndRewrite(IREE::VectorExt::LayoutConflictResolutionOp resolutionOp,
DistributionSignature &signature,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(IREE::VectorExt::ToLayoutOp resolutionOp,
DistributionSignature &signature,
PatternRewriter &rewriter) const override {
auto loc = resolutionOp.getLoc();
VectorValue vector = resolutionOp.getInput();
VectorValue result = resolutionOp.getOutput();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,7 @@ func.func @unresolved_layout_conflict(%a : memref<32x16xf16>, %b : memref<32x16x
%vcst = arith.constant dense<0.0> : vector<32x16xf16>
// CHECK-COUNT-8: vector.load %[[MEM]]
%vec = vector.transfer_read %a[%c0, %c0], %cst {"__vector_layout_test_anchor_result_0" = #layout1} : memref<32x16xf16>, vector<32x16xf16>
// CHECK: iree_vector_ext.layout_conflict_resolution {{.*}}
// CHECK: iree_vector_ext.to_layout {{.*}}
%vec2 = arith.addf %vec, %vcst : vector<32x16xf16>
// CHECK-COUNT-16: vector.store {{.*}}, vector<1xf16>
vector.transfer_write %vec2, %b[%c0, %c0] {in_bounds = [true, true],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1152,7 +1152,7 @@ static void emitLayoutRemarks(VectorLayoutAnalysis &analysis,
mlir::FunctionOpInterface funcOp) {
funcOp.walk([&](Operation *op) {
// Do not emit remarks for conflict operations.
if (isa<VectorExt::LayoutConflictResolutionOp>(op)) {
if (isa<VectorExt::ToLayoutOp>(op)) {
return;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,8 @@ ChangeResult DistributionLayout::resolveWithPossibleConflict(
Value input = opOperand.get();
// Create a resolution operation. This conflict should be handeled later by
// someone else, not this analysis.
Operation *resolveOp =
builder.create<IREE ::VectorExt::LayoutConflictResolutionOp>(
input.getLoc(), input.getType(), input, vectorLayout, rhs);
Operation *resolveOp = builder.create<IREE::VectorExt::ToLayoutOp>(
input.getLoc(), input.getType(), input, rhs);
Value resolvedValue = resolveOp->getResult(0);
opOperand.set(resolvedValue);

Expand Down Expand Up @@ -1015,9 +1014,9 @@ void VectorLayoutAnalysis::debugAnnotateLayouts() {
continue;
}

// Do not annotate resolve_conflict operations since they already have
// Do not annotate to_layout operations since they already have
// this information in their attributes.
if (isa<IREE::VectorExt::LayoutConflictResolutionOp>(op)) {
if (isa<IREE::VectorExt::ToLayoutOp>(op)) {
continue;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1676,10 +1676,8 @@ transform_dialect::SetContractionLayoutAttributes::apply(
if (!parentOp || (parentOp->getNumResults() != 1))
continue;
parentOp->setAttr("__vector_layout_test_anchor_result_0", readLayout);
Value resolvedOperand =
rewriter.create<VectorExt::LayoutConflictResolutionOp>(
contract.getLoc(), operand.getType(), operand, layout,
readLayout);
Value resolvedOperand = rewriter.create<VectorExt::ToLayoutOp>(
contract.getLoc(), operand.getType(), operand, layout);
contract.setOperand(operandIndices[i], resolvedOperand);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,23 +24,24 @@ class IREEVectorExt_PureOp<string mnemonic, list<Trait> traits = []> :
// Layout ops.
//===----------------------------------------------------------------------===//

def IREEVectorExt_LayoutConflictResolutionOp : IREEVectorExt_PureOp<"layout_conflict_resolution"> {
let summary = "Layout Conflict Resolution operator";
def IREEVectorExt_ToLayoutOp : IREEVectorExt_PureOp<"to_layout", [
Pure,
AllTypesMatch<["input", "output"]>
]> {
let summary = "Layout conversion operator";
let description = [{
The layout conflict resolution operator takes a vector and a
desired layout and transforms the vector to one with the
desired layout.
The layout conversion operator takes a shaped value and a layout and
transforms the value to have that layout.
}];
let arguments = (ins
AnyVector:$input,
VectorLayoutInterface:$sourceLayout,
VectorLayoutInterface:$desiredLayout
VectorLayoutInterface:$layout
);
let results = (outs
AnyVector:$output
);
let extraClassDeclaration = [{}];
let assemblyFormat = "$input attr-dict `:` type($input) `->` type($output)";
let assemblyFormat = "$input `to` $layout attr-dict `:` type($input)";
let hasVerifier = 1;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,9 @@ using VectorValue = TypedValue<VectorType>;
// LayoutConflictResolutionOp
//===----------------------------------------------------------------------===//

// Validate that the desired layout has the same shape as the input.
LogicalResult LayoutConflictResolutionOp::verify() {
if (getSourceLayout().isValidLayout(getInput()).failed()) {
return failure();
}
if (getDesiredLayout().isValidLayout(getOutput()).failed()) {
return failure();
}
return success();
// Validate that the layout has the same shape as the input.
LogicalResult ToLayoutOp::verify() {
return getLayout().isValidLayout(getInput());
}

// to_simd -> to_simt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,12 @@
#row_layout1 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX, VECTORY], [1, 1, 1]>
#col_layout1 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [4, 2, 4]>
#layout1 = #iree_vector_ext.layout<#row_layout1, #col_layout1>
#layout2 = #iree_vector_ext.layout<#col_layout1, #col_layout1>
func.func @invalid_desired_layout(%lhs: memref<32x32xf16>, %rhs: memref<32x32xf16>) -> vector<32x32xf16> {
func.func @invalid_layout(%lhs: memref<32x32xf16>, %rhs: memref<32x32xf16>) -> vector<32x32xf16> {
%cst_0 = arith.constant 0.0 : f16
%c0 = arith.constant 0 : index
%result = vector.transfer_read %lhs[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<32x32xf16>, vector<32x32xf16>
// expected-error @+1 {{Vector shape: [32, 32] does not match the layout (layout<<[ BATCHX, LANEX, VECTORY], [1, 1, 1]>, <[ BATCHY, LANEY, VECTORX], [4, 2, 4]>>) at dim 0. Dimension expected by layout: 1 actual: 32}}
%2 = iree_vector_ext.layout_conflict_resolution %result {desiredLayout = #layout1, sourceLayout = #layout2} : vector<32x32xf16> -> vector<32x32xf16>
return %2 : vector<32x32xf16>
}

// -----

#row_layout1 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX, VECTORY], [1, 1, 1]>
#col_layout1 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [4, 2, 4]>
#layout1 = #iree_vector_ext.layout<#row_layout1, #col_layout1>
#layout2 = #iree_vector_ext.layout<#col_layout1, #col_layout1>
func.func @invalid_source_layout(%lhs: memref<32x32xf16>, %rhs: memref<32x32xf16>) -> vector<32x32xf16> {
%cst_0 = arith.constant 0.0 : f16
%c0 = arith.constant 0 : index
%result = vector.transfer_read %lhs[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<32x32xf16>, vector<32x32xf16>
// expected-error @-1 {{Vector shape: [32, 32] does not match the layout (layout<<[ BATCHX, LANEX, VECTORY], [1, 1, 1]>, <[ BATCHY, LANEY, VECTORX], [4, 2, 4]>>) at dim 0. Dimension expected by layout: 1 actual: 32}}
%2 = iree_vector_ext.layout_conflict_resolution %result {desiredLayout = #layout2, sourceLayout = #layout1} : vector<32x32xf16> -> vector<32x32xf16>
%2 = iree_vector_ext.to_layout %result to #layout1 : vector<32x32xf16>
return %2 : vector<32x32xf16>
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,18 @@

#row_layout1 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX, VECTORY], [2, 4, 4]>
#col_layout1 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [4, 2, 4]>
#layout1 = #iree_vector_ext.layout<#row_layout1, #col_layout1>
#layout2 = #iree_vector_ext.layout<#col_layout1, #row_layout1>
func.func @specify_layout(%lhs: memref<32x32xf16>) -> vector<32x32xf16> {
%cst_0 = arith.constant 0.0 : f16
%c0 = arith.constant 0 : index
%result = vector.transfer_read %lhs[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<32x32xf16>, vector<32x32xf16>
%2 = iree_vector_ext.layout_conflict_resolution %result {sourceLayout = #layout1, desiredLayout = #layout2} : vector<32x32xf16> -> vector<32x32xf16>
%2 = iree_vector_ext.to_layout %result to #layout2 : vector<32x32xf16>
return %2 : vector<32x32xf16>
}

// CHECK-DAG: #[[LAYOUT0:.+]] = #iree_vector_ext.layout<<[ BATCHY, LANEY, VECTORX], [4, 2, 4]>, <[ BATCHX, LANEX, VECTORY], [2, 4, 4]>>
// CHECK-DAG: #[[LAYOUT1:.+]] = #iree_vector_ext.layout<<[ BATCHX, LANEX, VECTORY], [2, 4, 4]>, <[ BATCHY, LANEY, VECTORX], [4, 2, 4]>>
// CHECK-LABEL: func.func @specify_layout
// CHECK: iree_vector_ext.layout_conflict_resolution
// CHECK-SAME: desiredLayout = #[[LAYOUT0]]
// CHECK-SAME: sourceLayout = #[[LAYOUT1]]
// CHECK: iree_vector_ext.to_layout {{.*}} to #[[LAYOUT0]]

// -----

Expand Down

0 comments on commit 18c183f

Please sign in to comment.