Skip to content

Commit

Permalink
[Codegen][GPU] Allow iree_gpu.barrier_region to take multiple operand…
Browse files Browse the repository at this point in the history
…s/results (#18490)

The restriction to a single input and output was artificial as this op
simply represents synchronization on input and output values.
Additionally this removes the restriction on tensor/vector types, but
for the time being this op is still only used with those types.
  • Loading branch information
qedawkins committed Sep 19, 2024
1 parent fa44a32 commit c9eca66
Show file tree
Hide file tree
Showing 7 changed files with 220 additions and 111 deletions.
32 changes: 20 additions & 12 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
Expand All @@ -32,38 +33,45 @@ namespace mlir::iree_compiler::IREE::GPU {

// Build a BarrierRegionOp with an empty body block.
void BarrierRegionOp::build(OpBuilder &b, OperationState &result,
Type resultType, Value dest) {
result.addOperands(dest);
TypeRange resultTypes, ValueRange inputs) {
result.addOperands(inputs);
(void)result.addRegion();
result.addTypes(resultType);
result.addTypes(resultTypes);
SmallVector<Location> blockArgLocs(inputs.size(), result.location);

Region *region = result.regions[0].get();

// `builder.createBlock` changes the insertion point within the block. Create
// a guard to reset the insertion point of the builder after it is destroyed.
OpBuilder::InsertionGuard guard(b);
b.createBlock(region, region->end(), ArrayRef<Type>{dest.getType()},
ArrayRef<Location>{result.location});
b.createBlock(region, region->end(), inputs.getTypes(), blockArgLocs);
}

LogicalResult BarrierRegionOp::verify() {
  // There are no op-level invariants to check here, so verification always
  // succeeds; the block-argument and yield checks are in verifyRegions().
  return success();
}

LogicalResult BarrierRegionOp::verifyRegions() {
auto &region = getRegion();
Block &block = region.front();
if (block.getNumArguments() != 1) {
return emitError("expected the block to have a single argument");
if (block.getNumArguments() != getNumOperands()) {
return emitError(
"expected the block argument count to match operand count");
}

if (block.getArgumentTypes()[0] != getDestType()) {
return emitError("expected block to have single argument type of")
<< getDestType();
if (!llvm::all_of_zip(block.getArgumentTypes(), getOperandTypes(),
[](Type a, Type b) { return a == b; })) {
return emitError("expected block argument types to match operand types");
}

// Ensure that the region yields an element of the right type.
auto yieldOp = llvm::cast<GPU::YieldOp>(block.getTerminator());
if (yieldOp.getValue().getType() != getResult().getType()) {
return emitOpError("expected yield type to match result type");
if (yieldOp->getNumOperands() != getNumResults()) {
return emitOpError(
"expected body to yield same number of values as results");
}

if (!llvm::all_of_zip(yieldOp->getOperandTypes(), getResultTypes(),
[](Type a, Type b) { return a == b; })) {
return emitError("expected yielded value types to match result types");
}

return success();
Expand Down
53 changes: 24 additions & 29 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ def IREEGPU_BarrierRegionOp : Op<IREEGPU_Dialect, "barrier_region", [
]> {
let summary = "Synchronizes uses of a shared tensor.";
let description = [{
This op is designed to represent synchronization of workers on a
particular shared tensor. This operation naturally arises when combining
This op is designed to represent synchronization of workers on the operands
and results of the given region. This operation naturally arises when combining
the regions of producer-consumer `scf.forall` operations that share a
mapping type.

Expand Down Expand Up @@ -58,27 +58,26 @@ def IREEGPU_BarrierRegionOp : Op<IREEGPU_Dialect, "barrier_region", [

```mlir
%0 = scf.forall (%idy, %idx) in (8, 8) -> (tensor<4x128xf32>) {
%ids = affine.delinearize_index %idy * 8 + %idx to (2, 32) : index
%in = ...
%2 = affine.apply #affine_map<(d0) -> (d0 * 2)> (%ids#0)
%3 = affine.apply #affine_map<(d0) -> (d0 * 4)> (%ids#1)
%4 = affine.apply #affine_map<(d0) -> (d0 * 16)> (%idx)
%alloc = bufferization.alloc_tensor {memory_space = #gpu.address_space<workgroup>}
: tensor<4x128xf32>
%inserted_slice = tensor.insert_slice %in into %alloc[%2, %3] [2, 4] [1, 1]
: tensor<2x4xf32> to tensor<4x128xf32>
%slice = iree_gpu.barrier_region %inserted_slice {
%barrier = iree_gpu.barrier_region %alloc {
^bb0(%shared: tensor<4x128xf32>):
%slice = tensor.extract_slice %shared[0, %4] [4, 16] [1, 1] : tensor<4x128xf32> to tensor<4x16xf32>
%ids = affine.delinearize_index %idy * 8 + %idx to (2, 32) : index
%in = ...
%2 = affine.apply #affine_map<(d0) -> (d0 * 2)> (%ids#0)
%3 = affine.apply #affine_map<(d0) -> (d0 * 4)> (%ids#1)
%inserted_slice = tensor.insert_slice %in into %shared[%2, %3] [2, 4] [1, 1]
: tensor<2x4xf32> to tensor<4x128xf32>
iree_gpu.yield %slice : tensor<4x16xf32>
} : tensor<4x128xf32> -> tensor<4x16xf32>
%4 = affine.apply #affine_map<(d0) -> (d0 * 16)> (%idx)
%slice = tensor.extract_slice %barrier[0, %4] [4, 16] [1, 1] : tensor<4x128xf32> to tensor<4x16xf32>
...
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
```

A barrier_region can be lowered to two barriers, one on the |dest| input
operand, and a second one on the result. Note that the |dest| operand must
bufferize with memory space `#gpu.address_space<workgroup>`.
A barrier_region can be lowered to two barriers, one on the input operands
and a second one on the results.

Motivation and Intended Use Cases:

Expand All @@ -92,26 +91,20 @@ def IREEGPU_BarrierRegionOp : Op<IREEGPU_Dialect, "barrier_region", [
}];

let arguments = (ins
AnyRankedTensor:$dest
Variadic<AnyType>:$inputs
);
let regions = (region SizedRegion<1>:$region);
let results = (outs AnyRankedTensorOrVector:$result);
let results = (outs Variadic<AnyType>:$results);

let assemblyFormat = [{
$dest $region attr-dict
`:` type($dest) `->` type($result)
(`ins` `(` $inputs^ `:` type($inputs) `)` )?
$region attr-dict `:` type($results)
}];

let builders = [
OpBuilder<(ins "Type":$result_type, "Value":$dest)>
OpBuilder<(ins "TypeRange":$result_types, "ValueRange":$inputs)>
];

let extraClassDeclaration = [{
RankedTensorType getDestType() {
return getDest().getType();
}
}];

let skipDefaultBuilders = 1;
let hasVerifier = 1;
let hasRegionVerifier = 1;
Expand Down Expand Up @@ -448,14 +441,16 @@ def IREEGPU_ValueBarrierOp : Op<IREEGPU_Dialect, "value_barrier", [
def IREEGPU_YieldOp : Op<IREEGPU_Dialect, "yield", [
Pure, ReturnLike, Terminator,
HasParent<"::mlir::iree_compiler::IREE::GPU::BarrierRegionOp">]> {
let summary = "Yield a value from a region";
let summary = "Yield values from a region";
let description = [{
This operation is used to yield a single value from within a region.
This operation is used to yield values from within a region.
}];

let arguments = (ins AnyType:$value);
let assemblyFormat = "$value attr-dict `:` type($value)";
let arguments = (ins Variadic<AnyType>:$values);
let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];

let assemblyFormat =
[{ attr-dict ($values^ `:` type($values))? }];
}

#endif // IREE_CODEGEN_DIALECT_IREEGPUOPS
Original file line number Diff line number Diff line change
@@ -1,39 +1,53 @@
// RUN: iree-opt %s --split-input-file | FileCheck %s

func.func @barrier_region(%init: tensor<6x6xf32>) -> tensor<3x2xf32> {
%0 = iree_gpu.barrier_region %init {
%0 = iree_gpu.barrier_region ins(%init : tensor<6x6xf32>) {
^bb0(%intermediate: tensor<6x6xf32>):
%slice = tensor.extract_slice %intermediate[0, 0] [3, 2] [1, 1] : tensor<6x6xf32> to tensor<3x2xf32>
iree_gpu.yield %slice : tensor<3x2xf32>
} : tensor<6x6xf32> -> tensor<3x2xf32>
} : tensor<3x2xf32>
return %0 : tensor<3x2xf32>
}

// CHECK-LABEL: func @barrier_region
// CHECK-SAME: %[[INIT:[A-Za-z0-9]+]]: tensor<6x6xf32>
// CHECK: iree_gpu.barrier_region %[[INIT]] {
// CHECK: iree_gpu.barrier_region ins(%[[INIT]] : tensor<6x6xf32>) {
// CHECK: ^bb0(%[[INTERMEDIATE:.+]]: tensor<6x6xf32>):
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[INTERMEDIATE]][0, 0] [3, 2] [1, 1]
// CHECK: iree_gpu.yield %[[SLICE]] : tensor<3x2xf32>
// CHECK: } : tensor<6x6xf32> -> tensor<3x2xf32>
// CHECK: } : tensor<3x2xf32>

// -----

func.func @reshape_barrier_region(%init: tensor<12x12xf32>) -> tensor<2x1x3x2xf32> {
%0 = iree_gpu.barrier_region %init {
func.func @multi_result_barrier_region(%init: tensor<12x12xf32>) -> (tensor<2x1x3x2xf32>, index) {
%0:2 = iree_gpu.barrier_region ins(%init : tensor<12x12xf32>) {
^bb0(%intermediate: tensor<12x12xf32>):
%expand = tensor.expand_shape %intermediate [[0, 1], [2, 3]] output_shape [4, 3, 3, 4] : tensor<12x12xf32> into tensor<4x3x3x4xf32>
%slice = tensor.extract_slice %expand[0, 0, 0, 0] [2, 1, 3, 2] [1, 1, 1, 1] : tensor<4x3x3x4xf32> to tensor<2x1x3x2xf32>
iree_gpu.yield %slice : tensor<2x1x3x2xf32>
} : tensor<12x12xf32> -> tensor<2x1x3x2xf32>
return %0 : tensor<2x1x3x2xf32>
%c0 = arith.constant 0 : index
iree_gpu.yield %slice, %c0 : tensor<2x1x3x2xf32>, index
} : tensor<2x1x3x2xf32>, index
return %0#0, %0#1 : tensor<2x1x3x2xf32>, index
}

// CHECK-LABEL: func @reshape_barrier_region
// CHECK: iree_gpu.barrier_region
// CHECK: tensor.expand_shape
// CHECK: tensor.extract_slice
// CHECK: } : tensor<12x12xf32> -> tensor<2x1x3x2xf32>
// CHECK-LABEL: func @multi_result_barrier_region
// CHECK: %{{.*}}:2 = iree_gpu.barrier_region ins(%{{.*}} : tensor<12x12xf32>)
// CHECK: } : tensor<2x1x3x2xf32>, index

// -----

// Exercises the `ins(...)` form of iree_gpu.barrier_region with two `index`
// operands. The region's block takes one argument per operand (%ix, %iy) and
// yields a single `index` value, which becomes the op's only result.
func.func @multi_input_barrier_region(%x: index, %y: index) -> index {
%0 = iree_gpu.barrier_region ins(%x, %y : index, index) {
^bb0(%ix: index, %iy: index):
%sum = arith.addi %ix, %iy : index
iree_gpu.yield %sum : index
} : index
return %0 : index
}

// CHECK-LABEL: func @multi_input_barrier_region
// CHECK: %{{.*}} = iree_gpu.barrier_region ins(%{{.*}}, %{{.*}} : index, index)
// CHECK: } : index

// -----

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,11 @@ module attributes { transform.with_named_sequence } {
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[COPY]] into %[[ITER]][%[[INID0]], %[[IDS]]#0] [2, 128] [1, 1]
// CHECK: scf.yield %[[INSERT]]

// CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region %[[LOOP]]
// CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region ins(%[[LOOP]] : tensor<128x128xf32>)
// CHECK: ^bb0(%[[INTERMEDIATE:.+]]: tensor<128x128xf32>):
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[INTERMEDIATE]][%[[OUTID0]], %[[OUTID1]]] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
// CHECK: iree_gpu.yield %[[SLICE]]
// CHECK: } : tensor<128x128xf32> -> tensor<16x16xf32>
// CHECK: } : tensor<16x16xf32>
// CHECK: %[[OUTSLICE:.+]] = tensor.extract_slice %[[INIT]][%[[OUTID0]], %[[OUTID1]]] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
// CHECK: %[[MM:.+]] = linalg.matmul ins(%[[SHUFFLE]], %[[SHUFFLE]] : tensor<16x16xf32>, tensor<16x16xf32>)
// CHECK-SAME: outs(%[[OUTSLICE]] : tensor<16x16xf32>) -> tensor<16x16xf32>
Expand Down Expand Up @@ -124,8 +124,8 @@ module attributes { transform.with_named_sequence } {
// CHECK: scf.forall (%[[IDX:.+]], %[[IDY:.+]]) in (8, 8) shared_outs(%[[INIT:.+]] = %[[EMPTY]]) -> (tensor<128x128xf32>) {
// CHECK: %[[LOOP:.+]] = scf.for {{.*}} iter_args(%[[INIT:.+]] = %[[ALLOC]])
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %{{.*}} into %[[INIT]]
// CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region %[[LOOP]]
// CHECK: } : tensor<128x128xf32> -> tensor<16x16xf32>
// CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region ins(%[[LOOP]] : tensor<128x128xf32>)
// CHECK: } : tensor<16x16xf32>
// CHECK: } {mapping = [#gpu.warp<y>, #gpu.warp<x>]}

// -----
Expand Down Expand Up @@ -180,12 +180,12 @@ module attributes { transform.with_named_sequence } {
// CHECK: scf.forall (%[[IDX:.+]], %[[IDY:.+]]) in (8, 8) shared_outs(%[[INIT:.+]] = %[[EMPTY]]) -> (tensor<128x128xf32>) {
// CHECK: %[[LOOP:.+]] = scf.for {{.*}} iter_args(%[[INIT:.+]] = %[[ALLOC]])
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %{{.*}} into %[[INIT]]
// CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region %[[LOOP]]
// CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region ins(%[[LOOP]] : tensor<128x128xf32>)
// CHECK: ^bb0(%[[INTERMEDIATE:.+]]: tensor<128x128xf32>):
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[INTERMEDIATE]] {{\[}}[0, 1], [2]{{\]}} output_shape [2, 64, 128]
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[EXPAND]][0, %{{.*}}, %{{.*}}] [1, 16, 16] [1, 1, 1] : tensor<2x64x128xf32> to tensor<16x16xf32>
// CHECK: iree_gpu.yield %[[SLICE]]
// CHECK: } : tensor<128x128xf32> -> tensor<16x16xf32>
// CHECK: } : tensor<16x16xf32>
// CHECK: } {mapping = [#gpu.warp<y>, #gpu.warp<x>]}

// -----
Expand Down Expand Up @@ -253,8 +253,8 @@ module attributes { transform.with_named_sequence } {
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[COPY]] into %[[ITER]][%[[IDX]], %[[IDS]]#0] [2, 128]
// CHECK: scf.yield %[[INSERT]]

// CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region %[[LOOP]]
// CHECK: } : tensor<128x128xf32> -> tensor<16x16xf32>
// CHECK: %[[SHUFFLE:.+]] = iree_gpu.barrier_region ins(%[[LOOP]] : tensor<128x128xf32>)
// CHECK: } : tensor<16x16xf32>
// CHECK: } {mapping = [#iree_gpu.lane_id<1>, #iree_gpu.lane_id<0>]}
// CHECK: } {mapping = [#gpu.warp<y>, #gpu.warp<x>]}

Expand Down Expand Up @@ -308,7 +308,7 @@ module attributes { transform.with_named_sequence } {
// CHECK: %[[LOOP:.+]] = scf.for %[[I:.+]] = %[[LINEARID]] to %c32{{.*}} step %c64{{.*}} iter_args(%[[ITER:.+]] = %[[ALLOC]])
// CHECK: %[[IDS:.+]] = affine.delinearize_index %[[I]] into (%c32) : index
// CHECK: scf.yield
// CHECK: iree_gpu.barrier_region %[[LOOP]]
// CHECK: iree_gpu.barrier_region ins(%[[LOOP]]
// CHECK: } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}

// -----
Expand Down Expand Up @@ -363,5 +363,5 @@ module attributes { transform.with_named_sequence } {
// CHECK: %[[LOOP:.+]] = scf.for %[[I:.+]] = %[[LINEARID]] to %[[PRODCOUNT]] step %c64{{.*}} iter_args(%[[ITER:.+]] = %[[ALLOC]])
// CHECK: %[[IDS:.+]] = affine.delinearize_index %[[I]] into (%[[Z]], %[[Y]], %[[X]]) : index
// CHECK: scf.yield
// CHECK: iree_gpu.barrier_region %[[LOOP]]
// CHECK: iree_gpu.barrier_region ins(%[[LOOP]]
// CHECK: } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
// RUN: iree-opt %s -iree-transform-dialect-interpreter -transform-dialect-drop-schedule --split-input-file | FileCheck %s

func.func @barrier_region(%init: tensor<6x6xf32>, %x: index) -> tensor<3x2xf32> {
%0 = iree_gpu.barrier_region %init {
%0 = iree_gpu.barrier_region ins(%init : tensor<6x6xf32>) {
^bb0(%intermediate: tensor<6x6xf32>):
%slice = tensor.extract_slice %intermediate[0, %x] [3, 2] [1, 1] : tensor<6x6xf32> to tensor<3x2xf32>
iree_gpu.yield %slice : tensor<3x2xf32>
} : tensor<6x6xf32> -> tensor<3x2xf32>
} : tensor<3x2xf32>
return %0 : tensor<3x2xf32>
}

Expand Down Expand Up @@ -33,12 +33,12 @@ module attributes { transform.with_named_sequence } {
func.func @reshape_barrier_region(%init: tensor<12x12xf32>) -> vector<2x1x3x2xf32> {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.0 : f32
%0 = iree_gpu.barrier_region %init {
%0 = iree_gpu.barrier_region ins(%init : tensor<12x12xf32>) {
^bb0(%intermediate: tensor<12x12xf32>):
%expand = tensor.expand_shape %intermediate [[0, 1], [2, 3]] output_shape [4, 3, 3, 4] : tensor<12x12xf32> into tensor<4x3x3x4xf32>
%read = vector.transfer_read %expand[%c0, %c0, %c0, %c0], %cst : tensor<4x3x3x4xf32>, vector<2x1x3x2xf32>
iree_gpu.yield %read : vector<2x1x3x2xf32>
} : tensor<12x12xf32> -> vector<2x1x3x2xf32>
} : vector<2x1x3x2xf32>
return %0 : vector<2x1x3x2xf32>
}

Expand All @@ -59,3 +59,31 @@ module attributes { transform.with_named_sequence } {
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[WRITE_BARRIER]]
// CHECK: %[[READ:.+]] = vector.transfer_read %[[EXPAND]]
// CHECK: %[[READ_BARRIER:.+]] = iree_gpu.value_barrier %[[READ]]

// -----

// barrier_region with two tensor operands and two results. The body yields the
// block arguments in swapped order (%in1, then %in0), so the result types are
// (tensor<3xf32>, tensor<2xf32>) while the operand types are
// (tensor<2xf32>, tensor<3xf32>).
func.func @multi_barrier_region(%arg0: tensor<2xf32>, %arg1: tensor<3xf32>) -> (tensor<3xf32>, tensor<2xf32>) {
%0:2 = iree_gpu.barrier_region ins(%arg0, %arg1 : tensor<2xf32>, tensor<3xf32>) {
^bb0(%in0: tensor<2xf32>, %in1: tensor<3xf32>):
iree_gpu.yield %in1, %in0 : tensor<3xf32>, tensor<2xf32>
} : tensor<3xf32>, tensor<2xf32>
return %0#0, %0#1 : tensor<3xf32>, tensor<2xf32>
}

// Transform script for this test case: matches every `func.func` under %root
// and applies the `iree.lower_barrier_region` pattern set to the matched ops.
module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.iree.lower_barrier_region
} : !transform.any_op
transform.yield
}
}

// CHECK-LABEL: func @multi_barrier_region
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<2xf32>
// CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]: tensor<3xf32>

// CHECK: %[[WB:.+]]:2 = iree_gpu.value_barrier %[[ARG0]], %[[ARG1]]
// CHECK: %[[RB:.+]]:2 = iree_gpu.value_barrier %[[WB]]#1, %[[WB]]#0
// CHECK: return %[[RB]]#0, %[[RB]]#1 : tensor<3xf32>, tensor<2xf32>
Loading

0 comments on commit c9eca66

Please sign in to comment.