Skip to content

Commit

Permalink
Fix dominance issue.
Browse files Browse the repository at this point in the history
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
  • Loading branch information
MaheshRavishankar committed Aug 13, 2024
1 parent 6669bdf commit e187f32
Showing 1 changed file with 39 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,30 @@ static bool isEmptyFillContractionDAGRootOp(
return true;
}

/// Check that a given operation is "horizontal" to the group. The operation
/// is horizontal if the `slice` of the operation does not contain any op
/// from the group.
static bool isHorizontalToGroup(Operation *op,
const llvm::SetVector<Operation *> &currGroup,
const DominanceInfo &dominanceInfo,
Operation *seedOp) {
BackwardSliceOptions options;
// Limit the slice to the seed to make sure the slice is small.
options.filter = [&](Operation *op) {
return !dominanceInfo.properlyDominates(op, seedOp);
};
llvm::SetVector<Operation *> slice;
getBackwardSlice(op, &slice, options);
return !llvm::any_of(currGroup, [&](Operation *groupedOp) {
return slice.contains(groupedOp);
});
}

/// Get user of operation that is a truncate operation.
static std::optional<linalg::GenericOp>
getTruncateOp(Operation *op,
const llvm::SetVector<Operation *> &groupedOperations,
const DominanceInfo &dominanceInfo,
std::optional<linalg::GenericOp> seedTruncateOp = std::nullopt) {
if (!op->hasOneUse()) {
return std::nullopt;
Expand All @@ -128,6 +149,9 @@ getTruncateOp(Operation *op,
if (!checkOperationEquivalence(genericOp, seedTruncateOp.value())) {
return std::nullopt;
}
if (!isHorizontalToGroup(genericOp, groupedOperations, dominanceInfo, seedTruncateOp.value())) {
return std::nullopt;
}
}
return genericOp;
}
Expand Down Expand Up @@ -168,7 +192,7 @@ static std::optional<HorizontalFusionGroup> getHorizontalFusionGroupMembers(

SetVector<Operation *> allOps;
SmallVector<linalg::LinalgOp> contractionOps = {seedOp};
std::optional<linalg::GenericOp> seedTruncOp = getTruncateOp(seedOp);
std::optional<linalg::GenericOp> seedTruncOp = getTruncateOp(seedOp, allOps, dominanceInfo);
std::optional<SmallVector<linalg::GenericOp>> truncateOps;
if (seedTruncOp) {
truncateOps = {seedTruncOp.value()};
Expand All @@ -182,21 +206,27 @@ static std::optional<HorizontalFusionGroup> getHorizontalFusionGroupMembers(
if (linalgOp->getParentOp() != seedOp->getParentOp()) {
return false;
}
// The seed has to dominate the op.
if (!dominanceInfo.properlyDominates(seedOp, linalgOp)) {
return false;
}

// Constraints of the operation itself.
if (!isEmptyFillContractionDAGRootOp(linalgOp, seedOp)) {
return false;
}
if (groupedOperations.contains(linalgOp) || allOps.contains(linalgOp)) {
return false;
}
if (linalgOp->getOperand(0).getType() != lhsType ||
linalgOp->getOperand(1).getType() != rhsType ||
linalgOp->getOperand(2).getType() != outType) {
return false;
}
if (groupedOperations.contains(linalgOp)) {
return false;
}

// Structural constraints related to being able to fuse the operations.
if (!dominanceInfo.properlyDominates(seedOp, linalgOp)) {
return false;
}
if (!isHorizontalToGroup(linalgOp, allOps, dominanceInfo, seedOp)) {
return false;
}
return true;
};

Expand Down Expand Up @@ -227,7 +257,7 @@ static std::optional<HorizontalFusionGroup> getHorizontalFusionGroupMembers(
}

std::optional<linalg::GenericOp> userTruncOp =
getTruncateOp(linalgUser, seedTruncOp);
getTruncateOp(linalgUser, allOps, dominanceInfo, seedTruncOp);
// If there are truncate ops to fuse and current contraction op
// does not have a compatible truncate op to fuse as well, ignore
// the op for horizontal fusion.
Expand Down

0 comments on commit e187f32

Please sign in to comment.