
[Flow] Generalize horizontal contraction fusion to cover more cases. #17880

Conversation

MaheshRavishankar
Contributor

The current implementation of horizontal fusion only handled
contraction (-> truncf)? cases. To generalize it to more cases, the
following changes were needed (an illustrative sketch of the overall
rewrite follows the list):

  1. Instead of looking for a linalg.generic with a single
    arith.truncf, use the isBitTruncate utility method. To enable this,
    the pass is moved to Flow.

  2. Use OperationEquivalence to check that the contraction
    operations are "similar" and that any subsequent truncation operations
    are "similar" too.

  3. Instead of trying to find an insertion point based on the existing
    dominance relationships between operations:

    • Always insert the horizontally fused contraction before the first
      contraction.
    • Always insert the horizontally fused truncation operation before
      the first truncation operation.
  4. Instead of generating the fills/empties along with the horizontally fused
    operations, use separate patterns to fold concats of fills and
    concats of empties into a fill and an empty, respectively (see the
    sketch below).
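
For illustration, a schematic, hand-written sketch of the intended rewrite on two matmul -> truncf pairs that share a LHS. The shapes, the concatenation dimension, and the way results are sliced back out are illustrative assumptions, not the literal output of the pass:

// Before: two "similar" contractions on the same LHS, each followed by a
// "similar" bit-truncating linalg.generic.
%0 = linalg.matmul ins(%lhs, %rhs0 : tensor<2x1280xf32>, tensor<1280x320xf32>)
       outs(%fill0 : tensor<2x320xf32>) -> tensor<2x320xf32>
%0t = linalg.generic ... ins(%0 : tensor<2x320xf32>) outs(%e0 : tensor<2x320xf16>) { arith.truncf } -> tensor<2x320xf16>
%1 = linalg.matmul ins(%lhs, %rhs1 : tensor<2x1280xf32>, tensor<1280x320xf32>)
       outs(%fill1 : tensor<2x320xf32>) -> tensor<2x320xf32>
%1t = linalg.generic ... ins(%1 : tensor<2x320xf32>) outs(%e1 : tensor<2x320xf16>) { arith.truncf } -> tensor<2x320xf16>

// After: one contraction over the concatenated RHS operands, inserted before
// the first contraction, and one truncation inserted before the first
// truncation; the per-op results are sliced back out.
%rhs = tensor.concat dim(1) %rhs0, %rhs1
         : (tensor<1280x320xf32>, tensor<1280x320xf32>) -> tensor<1280x640xf32>
// The concat of the two fills (and of the two empties) is folded by the new
// patterns into a single linalg.fill on a single tensor.empty.
%init = linalg.fill ins(%cst : f32) outs(%empty : tensor<2x640xf32>) -> tensor<2x640xf32>
%f = linalg.matmul ins(%lhs, %rhs : tensor<2x1280xf32>, tensor<1280x640xf32>)
       outs(%init : tensor<2x640xf32>) -> tensor<2x640xf32>
%ft = linalg.generic ... ins(%f : tensor<2x640xf32>) outs(%et : tensor<2x640xf16>) { arith.truncf } -> tensor<2x640xf16>
%s0 = tensor.extract_slice %ft[0, 0] [2, 320] [1, 1] : tensor<2x640xf16> to tensor<2x320xf16>
%s1 = tensor.extract_slice %ft[0, 320] [2, 320] [1, 1] : tensor<2x640xf16> to tensor<2x320xf16>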

Signed-off-by: MaheshRavishankar mahesh.ravishankar@gmail.com

@MaheshRavishankar
Contributor Author

I tried to split this work into two commits; it might help to review the first and second commits separately. The first is mostly a move of the pass from GlobalOpt -> Flow. The second is a functional change to the pass. (Still need to add a test for the enhancements; working on it.)

Collaborator

@benvanik benvanik left a comment


Please no :)
As I noted on @IanWood1's PR the other day, that isBitTruncate is in RegionOpUtils is a bug - it should be in LinalgExt Utils/ - and this move seems to be motivated by the fact that it's in RegionOpUtils.

We want more things like this pass out of flow and not moved into it. As a general rule we really need to move towards a default where unless something is operating on flow ops it should not be in flow - something doing global optimizations on the program with tensor/linalg ops should be before flow in global opt.

@MaheshRavishankar
Contributor Author

> Please no :) As I noted on @IanWood1's PR the other day, that isBitTruncate is in RegionOpUtils is a bug - it should be in LinalgExt Utils/ - and this move seems to be motivated by the fact that it's in RegionOpUtils.
>
> We want more things like this pass out of flow and not moved into it. As a general rule we really need to move towards a default where unless something is operating on flow ops it should not be in flow - something doing global optimizations on the program with tensor/linalg ops should be before flow in global opt.

I saw the notification but didn't respond to it.

  1. I can move bitExtend/bitTruncate out, but I don't fully understand why. This is kind of needed for dispatch region formation if we want to ensure that extensions get fused with consumers and truncations get fused with producers, so given that is what happens in flow, it seems like it should be in flow. (I am not opposed to putting it in LinalgExt, but it seems strange to me.)
  2. I think my PR description is off... the move of this to flow is actually about more than bit truncation. Before elementwise operation fusion, the sequence of operations that convert from the compute type to the storage type isn't fused into a single op, which would make the analysis and transformation much more complicated. It is better to do this after elementwise operation fusion, so this pass needs to run after elementwise operation fusion to be really useful (see the illustrative sketch after this list). We can discuss moving the entire elementwise operation fusion + reshape propagation passes out of flow, but that's a separate discussion.
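
For illustration, a hand-written sketch (mirroring the IR that shows up later in this thread; shapes, operands, and indexing maps are elided or illustrative): the old pattern only matched a truncation generic whose body is a single arith.truncf, while after elementwise operation fusion the compute-to-storage conversion is a generic that mixes the surrounding arithmetic with the f32 -> f16 narrowing, which is what the generalized matching has to handle.

// Old pattern: the generic body is exactly one arith.truncf.
%t0 = linalg.generic ... ins(%acc : tensor<2x320xf32>) outs(%e : tensor<2x320xf16>) {
^bb0(%in: f32, %out: f16):
  %0 = arith.truncf %in : f32 to f16
  linalg.yield %0 : f16
} -> tensor<2x320xf16>

// After elementwise operation fusion, the truncation generic also carries the
// fused bias adds (compare the generics in the repro later in this thread):
%t1 = linalg.generic ... ins(%conv, %bias0, %mm, %bias1 : ...) outs(%e16 : tensor<2x320x128x128xf16>) {
^bb0(%in: f32, %in_0: f32, %in_1: f32, %in_2: f32, %out: f16):
  %0 = arith.addf %in_1, %in_2 : f32
  %1 = arith.addf %in, %in_0 : f32
  %2 = arith.truncf %0 : f32 to f16
  %3 = arith.truncf %1 : f32 to f16
  %4 = arith.addf %3, %2 : f16
  linalg.yield %4 : f16
} -> tensor<2x320x128x128xf16>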

@benvanik
Collaborator

benvanik commented Jul 12, 2024

RE 1: Ian wanted to use Flow's RegionOpUtils - a file for utilities related to flow dispatch regions - in common Util dialect code operating strictly on linalg ops (not flow ops). Here you're using Flow's RegionOpUtils - a file for utilities related to flow dispatch regions - in global opt operating strictly on linalg ops (not flow ops). So same reason in both cases: whether a linalg op has a certain trait is something related to linalg, not to flow, and anywhere we may want to ask about that linalg op should not have to pull in RegionOpUtils or depend on the flow dialect. Upstream in the linalg dialect would be a logical place for it but I know how tricky that can be, so LinalgExt (our place for stuff not in upstream linalg) is the closest appropriate place.

RE 2 makes sense to me, and as you say I think what we need to do is lop off the entire flow pipeline dealing with linalg prior to dispatch region formation and move that to global opt. If this is a transient step because of that then that's unfortunate but good enough reason. We really should work towards moving all this out, though - the flow pipeline should essentially start after dispatch region formation with the input to flow being formed dispatch regions. I suspect we want a new pipeline that sits between global opt and flow that is dedicated to dispatch region formation but don't know where to put it - global opt feels more appropriate as a staging ground. My hope is at some point we start by splitting the passes into two pipelines inside of flow, then move all the passes in that first pipeline forming dispatch regions out wholesale. That'd be really really good cleanup and help avoid situations like this where something really belongs in global opt but can't due to things being in flow.

@benvanik benvanik self-requested a review July 12, 2024 01:35
@MaheshRavishankar
Contributor Author

> RE 1: Ian wanted to use Flow's RegionOpUtils - a file for utilities related to flow dispatch regions - in common Util dialect code operating strictly on linalg ops (not flow ops). Here you're using Flow's RegionOpUtils - a file for utilities related to flow dispatch regions - in global opt operating strictly on linalg ops (not flow ops). So same reason in both cases: whether a linalg op has a certain trait is something related to linalg, not to flow, and anywhere we may want to ask about that linalg op should not have to pull in RegionOpUtils or depend on the flow dialect. Upstream in the linalg dialect would be a logical place for it but I know how tricky that can be, so LinalgExt (our place for stuff not in upstream linalg) is the closest appropriate place.

Found Ian's PR. I am happy to rebase on top of that and land this after that.

> RE 2 makes sense to me, and as you say I think what we need to do is lop off the entire flow pipeline dealing with linalg prior to dispatch region formation and move that to global opt. If this is a transient step because of that then that's unfortunate but good enough reason. We really should work towards moving all this out, though - the flow pipeline should essentially start after dispatch region formation with the input to flow being formed dispatch regions. I suspect we want a new pipeline that sits between global opt and flow that is dedicated to dispatch region formation but don't know where to put it - global opt feels more appropriate as a staging ground. My hope is at some point we start by splitting the passes into two pipelines inside of flow, then move all the passes in that first pipeline forming dispatch regions out wholesale. That'd be really really good cleanup and help avoid situations like this where something really belongs in global opt but can't due to things being in flow.

Yes, I think that's a good idea. I started doing this but haven't fully finished yet. Once I get some time to breathe, I can clean it up and improve its documentation. It would be really nice to have a pipeline, say "--iree-flow-form-default-dispatch-regions", that just does what it does today, and we can have others if need be. It is transient, but I will try to finish that up before I switch to the next thing. I am scared of GlobalOpt; it has become something of a dumping ground.

@qedawkins
Contributor

> RE 2 makes sense to me, and as you say I think what we need to do is lop off the entire flow pipeline dealing with linalg prior to dispatch region formation and move that to global opt. If this is a transient step because of that then that's unfortunate but good enough reason. We really should work towards moving all this out, though - the flow pipeline should essentially start after dispatch region formation with the input to flow being formed dispatch regions. I suspect we want a new pipeline that sits between global opt and flow that is dedicated to dispatch region formation but don't know where to put it - global opt feels more appropriate as a staging ground. My hope is at some point we start by splitting the passes into two pipelines inside of flow, then move all the passes in that first pipeline forming dispatch regions out wholesale. That'd be really really good cleanup and help avoid situations like this where something really belongs in global opt but can't due to things being in flow.

If we plan to do a split like this, I think it would be worth putting up an issue describing what we plan to gain by splitting Flow/Transforms/Passes.cpp into separate pipelines. A few high level bullets I've seen come up:

  1. Device placement seems to be designed to happen after dispatch region formation, but before Stream where we no longer have good visibility into the dispatches.
  2. Same for setting encodings. The difference with encodings is that there will inherently need to be some dependence on upstream dialects, perhaps facilitated with interfaces, but it is a separate dialect that (I'm assuming) Flow would take a dependence on.
  3. We could consider moving constant expression hoisting after dispatch region formation. This way any operations that are preferred to be fused with consumers (BitExtend/broadcast/etc...) can be controlled by dispatch region formation rather than complex/opinionated hoisting interface implementations. We can always have hoisting as a part of multiple phases either way, and we'd definitely want it after setting encodings/device placement regardless.
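
A hypothetical sketch of point 3 (op names, shapes, and the quantization scheme are illustrative, not taken from a real model): an i4 -> f16 dequantization of a constant-derived weight is the kind of bit-extend we would rather let dispatch region formation fuse into the consuming contraction's dispatch than have constant expression hoisting materialize as a widened constant.

// Hypothetical dequantization of a constant-derived weight.
%dq = linalg.generic ... ins(%w_i4, %scales : tensor<320x1280xi4>, tensor<320xf16>)
        outs(%init : tensor<320x1280xf16>) {
^bb0(%in: i4, %in_s: f16, %out: f16):
  %0 = arith.extui %in : i4 to i32
  %1 = arith.uitofp %0 : i32 to f16
  %2 = arith.mulf %1, %in_s : f16
  linalg.yield %2 : f16
} -> tensor<320x1280xf16>
// If %dq is hoisted as a constant expression, the widened f16 tensor is
// materialized; if left to dispatch region formation, the bit-extend fuses
// into this consumer instead.
%mm = linalg.matmul_transpose_b ins(%x, %dq : tensor<2x1280xf16>, tensor<320x1280xf16>)
        outs(%acc : tensor<2x320xf32>) -> tensor<2x320xf32>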

> My hope is at some point we start by splitting the passes into two pipelines inside of flow, then move all the passes in that first pipeline forming dispatch regions out wholesale. That'd be really really good cleanup and help avoid situations like this where something really belongs in global opt but can't due to things being in flow.

I agree, and it seems that Mahesh agrees, but I don't think it will be easy to make progress without writing out the why somewhere. Echoing Benoit's statement in another thread, issues are a better place for broad design discussion. One aspect of what I think you're asking for, @benvanik, is why we want to invest so heavily in splitting dependencies/adding interfaces. What I've gathered from public discussions is that much of the reason for preferring interfaces over a direct dependence, where possible, is to allow for plugging in downstream dialects. I haven't seen a design doc that describes what a downstream dialect would need to/could do to plug in to IREE, and adding a doc like that could be excellent to point at in discussions like this (maybe something like that exists and I missed it).

There is also a certain degree of not having a good place to put things related to upstream dialects. If IREE is intended as an upstream (with an actual e2e flow instead of just a basket of lit tests), then we need to define the set of critical interfaces/transformations a dialect needs to implement for IREE to know how to work with it. I've seen good examples of this for Stream/HAL, but I haven't seen the same for the kinds of transformations Flow does (dispatch region formation). I see that as a good prerequisite to splitting apart Flow's transforms and also to starting to unwind the GlobalOptimization kitchen sink.

@MaheshRavishankar MaheshRavishankar force-pushed the sdxl_quantized_horizontal_contraction_fusion branch from eac0af6 to 16f7ebf on July 12, 2024 06:10
@MaheshRavishankar MaheshRavishankar force-pushed the sdxl_quantized_horizontal_contraction_fusion branch 2 times, most recently from 6947ab0 to fa85e73 on August 1, 2024 06:18
@MaheshRavishankar MaheshRavishankar marked this pull request as ready for review August 1, 2024 06:18
@MaheshRavishankar
Contributor Author

Marking this ready for review. Also this is for now in Flow, but will need to be moved to a separate pass pipeline that deals with dispatch region formation. See #18063 (comment)

@IanWood1
Contributor

IanWood1 commented Aug 8, 2024

The regression test is failing because of a dominance error during the pass.

Consider this example (actual MLIR here: IanWood1@de471fd):

%0 = matmul %arg0, %arg1
%1 = truncf %0
%2 = matmul %arg0, %1
%3 = truncf %2

%2 gets moved (via fusion) above %1, meaning that %1 is used before it is defined.

I think we just need to check that if an operand is defined by an operation, that operation properly dominates all other operations in the fusion group. This would be a more conservative approach, but I think it would ensure correctness.

Edit: The actual issue is the use-def relationship between the ops. %1's operand would get moved above the fused op if there was no dependency (this is tested in util.func @horizontal_fusion_i8).
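
In distilled form (a hand-written sketch of the failure mode; names and shapes are illustrative):

%0  = linalg.matmul ins(%lhs, %rhs0 : ...) ...                      // first contraction in the group
%0t = linalg.generic ins(%0, ...) { ... arith.truncf ... }          // first truncation
// ... a chain of ops that transitively uses %0t ...
%dep = linalg.generic ins(%0t, ...) ...
%1  = linalg.matmul ins(%lhs, %rhs1 : ...) ...                      // second contraction in the group
%1t = linalg.generic ins(%1, %dep, ...) { ... arith.truncf ... }    // second truncation
// The horizontally fused truncation is inserted before %0t, but it reads %dep,
// which is defined after %0t - hence the dominance error.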

@MaheshRavishankar
Contributor Author

Thanks Ian! Could you post a snippet of the failing IR?

@IanWood1
Contributor

It's actually the truncation ops that have the dependency in this case. https://gist.github.com/IanWood1/f4ab6642601c6e005187f6e8e3c82a05

%3390 = linalg.matmul_transpose_b ins(%65, %_params.unet.up_blocks.2.resnets.0.time_emb_proj.weight : tensor<2x1280xf16>, tensor<320x1280xf16>) outs(%66 : tensor<2x320xf32>) -> tensor<2x320xf32>
%3391 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>, affine_map<(d0, d1, d2, d3) -> (d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%3389, %__hoisted_tensor_320xf32_371, %3390, %__hoisted_tensor_320xf32_372 : tensor<2x128x128x320xf32>, tensor<320xf32>, tensor<2x320xf32>, tensor<320xf32>) outs(%49 : tensor<2x320x128x128xf16>) {
^bb0(%in: f32, %in_2002: f32, %in_2003: f32, %in_2004: f32, %out: f16):
  %3459 = arith.addf %in_2003, %in_2004 : f32
  %3460 = arith.addf %in, %in_2002 : f32
  %3461 = arith.truncf %3459 : f32 to f16
  %3462 = arith.truncf %3460 : f32 to f16
  %3463 = arith.addf %3462, %3461 : f16
  linalg.yield %3463 : f16
} -> tensor<2x320x128x128xf16>
// omitted long chain of IR that connects %3391 and %3411
// ...
%3411 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%inserted_slice_1973, %__hoisted_tensor_3x3x640x320xf16 : tensor<2x130x130x640xf16>, tensor<3x3x640x320xf16>) outs(%45 : tensor<2x128x128x320xf32>) -> tensor<2x128x128x320xf32>
%3412 = linalg.matmul_transpose_b ins(%65, %_params.unet.up_blocks.2.resnets.1.time_emb_proj.weight : tensor<2x1280xf16>, tensor<320x1280xf16>) outs(%66 : tensor<2x320xf32>) -> tensor<2x320xf32>
%3413 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>, affine_map<(d0, d1, d2, d3) -> (d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%3411, %__hoisted_tensor_320xf32_376, %3412, %__hoisted_tensor_320xf32_377 : tensor<2x128x128x320xf32>, tensor<320xf32>, tensor<2x320xf32>, tensor<320xf32>) outs(%49 : tensor<2x320x128x128xf16>) {
^bb0(%in: f32, %in_2002: f32, %in_2003: f32, %in_2004: f32, %out: f16):
  %3459 = arith.addf %in_2003, %in_2004 : f32
  %3460 = arith.addf %in, %in_2002 : f32
  %3461 = arith.truncf %3459 : f32 to f16
  %3462 = arith.truncf %3460 : f32 to f16
  %3463 = arith.addf %3462, %3461 : f16
  linalg.yield %3463 : f16
} -> tensor<2x320x128x128xf16>

@MaheshRavishankar MaheshRavishankar force-pushed the sdxl_quantized_horizontal_contraction_fusion branch from 1287aa6 to 782178f on August 12, 2024 17:07
@MaheshRavishankar
Contributor Author

Thanks @IanWood1. Took your snippet and created a stand-alone repro for it:

util.func @repro(
    %65 : tensor<2x1280xf16>,
    %_params.unet.up_blocks.2.resnets.0.time_emb_proj.weight : tensor<320x1280xf16>,
    %3389 : tensor<2x128x128x320xf32>,
    %__hoisted_tensor_320xf32_371 : tensor<320xf32>,
    %__hoisted_tensor_320xf32_372 : tensor<320xf32>,
    %1962 : tensor<2x320x128x128xf16>,
    %_params.unet.up_blocks.2.resnets.0.norm2.weight : tensor<320xf16>,
    %_params.unet.up_blocks.2.resnets.0.norm2.bias : tensor<320xf16>,
    %__hoisted_tensor_3x3x320x320xf16_374 : tensor<3x3x320x320xf16>,
    %__hoisted_tensor_1x1x960x320xf16 : tensor<1x1x960x320xf16>,
    %__hoisted_tensor_320xf32_375 : tensor<320xf32>,
    %__hoisted_tensor_320xf32_373 : tensor<320xf32>,
    %76 : tensor<2x128x128x320xf16>,
    %_params.unet.up_blocks.2.resnets.1.norm1.weight : tensor<640xf16>,
    %_params.unet.up_blocks.2.resnets.1.norm1.bias : tensor<640xf16>,
    %__hoisted_tensor_3x3x640x320xf16 : tensor<3x3x640x320xf16>,
    %__hoisted_tensor_320xf32_376 : tensor<320xf32>,
    %__hoisted_tensor_320xf32_377 : tensor<320xf32>,
    %inserted_slice_1956 : tensor<960x2x128x128xf16>,
    %_params.unet.up_blocks.2.resnets.1.time_emb_proj.weight : tensor<320x1280xf16>) -> tensor<2x320x128x128xf16> {
  %cst = arith.constant 0.0 : f32
  %cst_0 = arith.constant 0.0 : f32
  %cst16_0 = arith.constant 0.0 : f16
  %cst_11 = arith.constant 11.0 : f32
  %cst_14 = arith.constant 14.0 : f32
  %cst_17 = arith.constant 17.0 : f16
  %44 = tensor.empty() : tensor<2x128x128x320xf32>
  %45 = linalg.fill ins(%cst : f32) outs(%44 : tensor<2x128x128x320xf32>) -> tensor<2x128x128x320xf32>
  %49 = tensor.empty() : tensor<2x320x128x128xf16>
  %50 = tensor.empty() : tensor<2x128x128x320xf16>
  %53 = tensor.empty() : tensor<2x32x10x16384xf32>
  %55 = tensor.empty() : tensor<2x32xf32>
  %56 = linalg.fill ins(%cst : f32) outs(%55 : tensor<2x32xf32>) -> tensor<2x32xf32>
  %62 = tensor.empty() : tensor<2x130x130x320xf16>
  %63 = linalg.fill ins(%cst16_0 : f16) outs(%62 : tensor<2x130x130x320xf16>) -> tensor<2x130x130x320xf16>
  %64 = tensor.empty() : tensor<2x320xf32>
  %66 = linalg.fill ins(%cst : f32) outs(%64 : tensor<2x320xf32>) -> tensor<2x320xf32>
  %3364 = tensor.empty() : tensor<2x640x128x128xf16>
  %3366 = tensor.empty() : tensor<2x128x128x640xf16>
  %3369 = tensor.empty() : tensor<2x130x130x640xf16>
  %3372 = tensor.empty() : tensor<640x2x128x128xf16>
  %3374 = tensor.empty() : tensor<320x2x128x128xf16>
  %3385 = tensor.empty() : tensor<2x128x128x960xf16>
  %3390 = linalg.matmul_transpose_b ins(%65, %_params.unet.up_blocks.2.resnets.0.time_emb_proj.weight : tensor<2x1280xf16>, tensor<320x1280xf16>) outs(%66 : tensor<2x320xf32>) -> tensor<2x320xf32>
  %3391 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>, affine_map<(d0, d1, d2, d3) -> (d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%3389, %__hoisted_tensor_320xf32_371, %3390, %__hoisted_tensor_320xf32_372 : tensor<2x128x128x320xf32>, tensor<320xf32>, tensor<2x320xf32>, tensor<320xf32>) outs(%49 : tensor<2x320x128x128xf16>) {
  ^bb0(%in: f32, %in_2002: f32, %in_2003: f32, %in_2004: f32, %out: f16):
    %3459 = arith.addf %in_2003, %in_2004 : f32
    %3460 = arith.addf %in, %in_2002 : f32
    %3461 = arith.truncf %3459 : f32 to f16
    %3462 = arith.truncf %3460 : f32 to f16
    %3463 = arith.addf %3462, %3461 : f16
    linalg.yield %3463 : f16
  } -> tensor<2x320x128x128xf16>
  %collapsed_1962 = tensor.collapse_shape %3391 [[0], [1], [2, 3]] : tensor<2x320x128x128xf16> into tensor<2x320x16384xf16>
  %expanded_1963 = tensor.expand_shape %collapsed_1962 [[0], [1, 2], [3]] output_shape [2, 32, 10, 16384] : tensor<2x320x16384xf16> into tensor<2x32x10x16384xf16>
  %3392 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_1963 : tensor<2x32x10x16384xf16>) outs(%53 : tensor<2x32x10x16384xf32>) {
  ^bb0(%in: f16, %out: f32):
    %3459 = arith.extf %in : f16 to f32
    linalg.yield %3459 : f32
  } -> tensor<2x32x10x16384xf32>
  %3393 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%3392 : tensor<2x32x10x16384xf32>) outs(%56 : tensor<2x32xf32>) {
  ^bb0(%in: f32, %out: f32):
    %3459 = arith.addf %in, %out : f32
    linalg.yield %3459 : f32
  } -> tensor<2x32xf32>
  %3394 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%3393 : tensor<2x32xf32>) outs(%55 : tensor<2x32xf32>) {
  ^bb0(%in: f32, %out: f32):
    %3459 = arith.divf %in, %cst_11 : f32
    linalg.yield %3459 : f32
  } -> tensor<2x32xf32>
  %3395 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%3392, %3394 : tensor<2x32x10x16384xf32>, tensor<2x32xf32>) outs(%56 : tensor<2x32xf32>) {
  ^bb0(%in: f32, %in_2002: f32, %out: f32):
    %3459 = arith.subf %in, %in_2002 : f32
    %3460 = arith.mulf %3459, %3459 : f32
    %3461 = arith.addf %3460, %out : f32
    linalg.yield %3461 : f32
  } -> tensor<2x32xf32>
  %3396 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_1963, %3394, %3395 : tensor<2x32x10x16384xf16>, tensor<2x32xf32>, tensor<2x32xf32>) outs(%53 : tensor<2x32x10x16384xf32>) {
  ^bb0(%in: f16, %in_2002: f32, %in_2003: f32, %out: f32):
    %3459 = arith.divf %in_2003, %cst_11 : f32
    %3460 = arith.addf %3459, %cst_14 : f32
    %3461 = math.rsqrt %3460 : f32
    %3462 = arith.extf %in : f16 to f32
    %3463 = arith.subf %3462, %in_2002 : f32
    %3464 = arith.mulf %3463, %3461 : f32
    linalg.yield %3464 : f32
  } -> tensor<2x32x10x16384xf32>
  %collapsed_1964 = tensor.collapse_shape %3396 [[0], [1, 2], [3]] : tensor<2x32x10x16384xf32> into tensor<2x320x16384xf32>
  %expanded_1965 = tensor.expand_shape %collapsed_1964 [[0], [1], [2, 3]] output_shape [2, 320, 128, 128] : tensor<2x320x16384xf32> into tensor<2x320x128x128xf32>
  %3397 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d3)>, affine_map<(d0, d1, d2, d3) -> (d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_1965, %_params.unet.up_blocks.2.resnets.0.norm2.weight, %_params.unet.up_blocks.2.resnets.0.norm2.bias : tensor<2x320x128x128xf32>, tensor<320xf16>, tensor<320xf16>) outs(%50 : tensor<2x128x128x320xf16>) {
  ^bb0(%in: f32, %in_2002: f16, %in_2003: f16, %out: f16):
    %3459 = arith.extf %in_2002 : f16 to f32
    %3460 = arith.mulf %in, %3459 : f32
    %3461 = arith.extf %in_2003 : f16 to f32
    %3462 = arith.addf %3460, %3461 : f32
    %3463 = arith.truncf %3462 : f32 to f16
    %3464 = arith.negf %3463 : f16
    %3465 = math.exp %3464 : f16
    %3466 = arith.addf %3465, %cst_17 : f16
    %3467 = arith.divf %cst_17, %3466 : f16
    %3468 = arith.mulf %3467, %3463 : f16
    linalg.yield %3468 : f16
  } -> tensor<2x128x128x320xf16>
  %inserted_slice_1966 = tensor.insert_slice %3397 into %63[0, 1, 1, 0] [2, 128, 128, 320] [1, 1, 1, 1] : tensor<2x128x128x320xf16> into tensor<2x130x130x320xf16>
  %3398 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%inserted_slice_1966, %__hoisted_tensor_3x3x320x320xf16_374 : tensor<2x130x130x320xf16>, tensor<3x3x320x320xf16>) outs(%45 : tensor<2x128x128x320xf32>) -> tensor<2x128x128x320xf32>
  %3399 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%inserted_slice_1956 : tensor<960x2x128x128xf16>) outs(%3385 : tensor<2x128x128x960xf16>) {
  ^bb0(%in: f16, %out: f16):
    linalg.yield %in : f16
  } -> tensor<2x128x128x960xf16>
  %3400 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%3399, %__hoisted_tensor_1x1x960x320xf16 : tensor<2x128x128x960xf16>, tensor<1x1x960x320xf16>) outs(%45 : tensor<2x128x128x320xf32>) -> tensor<2x128x128x320xf32>
  %3401 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2, d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0)>, affine_map<(d0, d1, d2, d3) -> (d1, d2, d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%3400, %__hoisted_tensor_320xf32_375, %3398, %__hoisted_tensor_320xf32_373 : tensor<2x128x128x320xf32>, tensor<320xf32>, tensor<2x128x128x320xf32>, tensor<320xf32>) outs(%3374 : tensor<320x2x128x128xf16>) {
  ^bb0(%in: f32, %in_2002: f32, %in_2003: f32, %in_2004: f32, %out: f16):
    %3459 = arith.addf %in_2003, %in_2004 : f32
    %3460 = arith.addf %in, %in_2002 : f32
    %3461 = arith.truncf %3459 : f32 to f16
    %3462 = arith.truncf %3460 : f32 to f16
    %3463 = arith.addf %3462, %3461 : f16
    linalg.yield %3463 : f16
  } -> tensor<320x2x128x128xf16>
  %3402 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2, d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%76 : tensor<2x128x128x320xf16>) outs(%3374 : tensor<320x2x128x128xf16>) {
  ^bb0(%in: f16, %out: f16):
    linalg.yield %in : f16
  } -> tensor<320x2x128x128xf16>
  %inserted_slice_1967 = tensor.insert_slice %3401 into %3372[0, 0, 0, 0] [320, 2, 128, 128] [1, 1, 1, 1] : tensor<320x2x128x128xf16> into tensor<640x2x128x128xf16>
  %inserted_slice_1968 = tensor.insert_slice %3402 into %inserted_slice_1967[320, 0, 0, 0] [320, 2, 128, 128] [1, 1, 1, 1] : tensor<320x2x128x128xf16> into tensor<640x2x128x128xf16>
  %3403 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%inserted_slice_1968 : tensor<640x2x128x128xf16>) outs(%3364 : tensor<2x640x128x128xf16>) {
  ^bb0(%in: f16, %out: f16):
    linalg.yield %in : f16
  } -> tensor<2x640x128x128xf16>
  %collapsed_1969 = tensor.collapse_shape %3403 [[0], [1], [2, 3]] : tensor<2x640x128x128xf16> into tensor<2x640x16384xf16>
  %expanded_1970 = tensor.expand_shape %collapsed_1969 [[0], [1, 2], [3]] output_shape [2, 32, 20, 16384] : tensor<2x640x16384xf16> into tensor<2x32x20x16384xf16>
  %3404 = tensor.empty() : tensor<2x32x20x16384xf32>
  %3405 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_1970 : tensor<2x32x20x16384xf16>) outs(%3404 : tensor<2x32x20x16384xf32>) {
  ^bb0(%in: f16, %out: f32):
    %3459 = arith.extf %in : f16 to f32
    linalg.yield %3459 : f32
  } -> tensor<2x32x20x16384xf32>
  %3406 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%3405 : tensor<2x32x20x16384xf32>) outs(%56 : tensor<2x32xf32>) {
  ^bb0(%in: f32, %out: f32):
    %3459 = arith.addf %in, %out : f32
    linalg.yield %3459 : f32
  } -> tensor<2x32xf32>
  %3407 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%3406 : tensor<2x32xf32>) outs(%55 : tensor<2x32xf32>) {
  ^bb0(%in: f32, %out: f32):
    %3459 = arith.divf %in, %cst_0 : f32
    linalg.yield %3459 : f32
  } -> tensor<2x32xf32>
  %3408 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%3405, %3407 : tensor<2x32x20x16384xf32>, tensor<2x32xf32>) outs(%56 : tensor<2x32xf32>) {
  ^bb0(%in: f32, %in_2002: f32, %out: f32):
    %3459 = arith.subf %in, %in_2002 : f32
    %3460 = arith.mulf %3459, %3459 : f32
    %3461 = arith.addf %3460, %out : f32
    linalg.yield %3461 : f32
  } -> tensor<2x32xf32>
  %3409 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_1970, %3407, %3408 : tensor<2x32x20x16384xf16>, tensor<2x32xf32>, tensor<2x32xf32>) outs(%3404 : tensor<2x32x20x16384xf32>) {
  ^bb0(%in: f16, %in_2002: f32, %in_2003: f32, %out: f32):
    %3459 = arith.divf %in_2003, %cst_0 : f32
    %3460 = arith.addf %3459, %cst_14 : f32
    %3461 = math.rsqrt %3460 : f32
    %3462 = arith.extf %in : f16 to f32
    %3463 = arith.subf %3462, %in_2002 : f32
    %3464 = arith.mulf %3463, %3461 : f32
    linalg.yield %3464 : f32
  } -> tensor<2x32x20x16384xf32>
  %collapsed_1971 = tensor.collapse_shape %3409 [[0], [1, 2], [3]] : tensor<2x32x20x16384xf32> into tensor<2x640x16384xf32>
  %expanded_1972 = tensor.expand_shape %collapsed_1971 [[0], [1], [2, 3]] output_shape [2, 640, 128, 128] : tensor<2x640x16384xf32> into tensor<2x640x128x128xf32>
  %3410 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d3)>, affine_map<(d0, d1, d2, d3) -> (d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_1972, %_params.unet.up_blocks.2.resnets.1.norm1.weight, %_params.unet.up_blocks.2.resnets.1.norm1.bias : tensor<2x640x128x128xf32>, tensor<640xf16>, tensor<640xf16>) outs(%3366 : tensor<2x128x128x640xf16>) {
  ^bb0(%in: f32, %in_2002: f16, %in_2003: f16, %out: f16):
    %3459 = arith.extf %in_2002 : f16 to f32
    %3460 = arith.mulf %in, %3459 : f32
    %3461 = arith.extf %in_2003 : f16 to f32
    %3462 = arith.addf %3460, %3461 : f32
    %3463 = arith.truncf %3462 : f32 to f16
    %3464 = arith.negf %3463 : f16
    %3465 = math.exp %3464 : f16
    %3466 = arith.addf %3465, %cst_17 : f16
    %3467 = arith.divf %cst_17, %3466 : f16
    %3468 = arith.mulf %3467, %3463 : f16
    linalg.yield %3468 : f16
  } -> tensor<2x128x128x640xf16>
  %inserted_slice_1973 = tensor.insert_slice %3410 into %3369[0, 1, 1, 0] [2, 128, 128, 640] [1, 1, 1, 1] : tensor<2x128x128x640xf16> into tensor<2x130x130x640xf16>
  %3411 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%inserted_slice_1973, %__hoisted_tensor_3x3x640x320xf16 : tensor<2x130x130x640xf16>, tensor<3x3x640x320xf16>) outs(%45 : tensor<2x128x128x320xf32>) -> tensor<2x128x128x320xf32>
  %3412 = linalg.matmul_transpose_b ins(%65, %_params.unet.up_blocks.2.resnets.1.time_emb_proj.weight : tensor<2x1280xf16>, tensor<320x1280xf16>) outs(%66 : tensor<2x320xf32>) -> tensor<2x320xf32>
  %3413 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>, affine_map<(d0, d1, d2, d3) -> (d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%3411, %__hoisted_tensor_320xf32_376, %3412, %__hoisted_tensor_320xf32_377 : tensor<2x128x128x320xf32>, tensor<320xf32>, tensor<2x320xf32>, tensor<320xf32>) outs(%49 : tensor<2x320x128x128xf16>) {
  ^bb0(%in: f32, %in_2002: f32, %in_2003: f32, %in_2004: f32, %out: f16):
    %3459 = arith.addf %in_2003, %in_2004 : f32
    %3460 = arith.addf %in, %in_2002 : f32
    %3461 = arith.truncf %3459 : f32 to f16
    %3462 = arith.truncf %3460 : f32 to f16
    %3463 = arith.addf %3462, %3461 : f16
    linalg.yield %3463 : f16
  } -> tensor<2x320x128x128xf16>
  util.return %3413 : tensor<2x320x128x128xf16>
}

I can reproduce the error on this. Looking more.

@MaheshRavishankar MaheshRavishankar force-pushed the sdxl_quantized_horizontal_contraction_fusion branch from e187f32 to 4c80b8f on August 13, 2024 06:23
@MaheshRavishankar
Contributor Author

@saienduri after this PR lands, maybe the thresholds need adjusting.

@MaheshRavishankar
Contributor Author

@benvanik could you take a look again? Once this and a few others land, we can move all of these out of flow.

@MaheshRavishankar MaheshRavishankar merged commit 9c951ca into iree-org:main Aug 13, 2024
45 checks passed