Commit

[TorchToTcp] Make tensor operand checks in `aten.convolution -> tcp.custom_op` lit test explicit (cruise-automation#28)

As titled.
sjain-stanford committed Jan 16, 2024
1 parent 85c1e73 commit 90768ec
Showing 2 changed files with 20 additions and 14 deletions.
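
For background on the two check styles involved, here is a minimal sketch (the op name "some.op" and the value names are illustrative, not from this diff). The old test matched the custom op's operands with anonymous wildcards, so any three SSA values would pass; the updated test captures each operand as a named FileCheck variable where it is defined and requires that exact value at the use site:

// Old style: wildcards accept any three operands.
// CHECK: tcp.custom_op("some.op") %{{.*}}, %{{.*}}, %{{.*}}

// New style: capture each operand at its definition ...
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]]
// ... and require the same value where the custom op consumes it.
// CHECK: tcp.custom_op("some.op") %[[T0]], %[[T1]], %[[T2]]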
1 change: 0 additions & 1 deletion lib/Conversion/TorchToTcp/TcpCustomOp.cpp
@@ -19,7 +19,6 @@
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"

#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Debug.h"

using namespace mlir;
using namespace mlir::tcp;
33 changes: 20 additions & 13 deletions test/Conversion/TorchToTcp/tcp_custom_ops.mlir
@@ -1,4 +1,4 @@
-// RUN: tcp-opt <%s -convert-torch-to-tcp-custom-op -canonicalize -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: tcp-opt <%s -convert-torch-to-tcp-custom-op -canonicalize -split-input-file | FileCheck %s

// CHECK-LABEL: func.func @torch.aten.gather_op(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,2],si64>
@@ -58,18 +58,27 @@ func.func @torch.aten.index_put_impl_op(%arg0: !torch.vtensor<[25],f32>, %arg1:
return %1 : !torch.vtensor<[25],f32>
}


// -----

-// CHECK: tcp.custom_op("torch.aten.convolution") %{{.*}}, %{{.*}}, %{{.*}} {
-// CHECK-SAME: dilation = [1 : index, 1 : index],
-// CHECK-SAME: groups = 1 : i64,
-// CHECK-SAME: output_padding = [1 : index, 1 : index],
-// CHECK-SAME: padding = [1 : index, 1 : index],
-// CHECK-SAME: stride = [2 : index, 2 : index],
-// CHECK-SAME: torch_operand_names = ["input", "weight", "bias"],
-// CHECK-SAME: transposed = true} : tensor<1x64x1x100xf32>, tensor<64x64x3x3xf32>, tensor<64xf32> -> tensor<1x64x2x200xf32>
-func.func @torcn.aten.transposed_convolution(%input: !torch.vtensor<[1,64,1,100],f32>) -> !torch.vtensor<[1,64,2,200],f32> {
+// CHECK-LABEL: func.func @torch.aten.transposed_convolution(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,64,1,100],f32>) -> !torch.vtensor<[1,64,2,200],f32>
+// CHECK: %[[T0:.*]] = torch.vtensor.literal(dense<0.000000e+00> : tensor<64xf32>) : !torch.vtensor<[64],f32>
+// CHECK: %[[T1:.*]] = torch.vtensor.literal(dense<0.000000e+00> : tensor<64x64x3x3xf32>) : !torch.vtensor<[64,64,3,3],f32>
+// CHECK: %[[T2:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1,64,1,100],f32> -> tensor<1x64x1x100xf32>
+// CHECK: %[[T3:.*]] = torch_c.to_builtin_tensor %[[T1]] : !torch.vtensor<[64,64,3,3],f32> -> tensor<64x64x3x3xf32>
+// CHECK: %[[T4:.*]] = torch_c.to_builtin_tensor %[[T0]] : !torch.vtensor<[64],f32> -> tensor<64xf32>
+// CHECK: %[[CUSTOM:.*]] = tcp.custom_op("torch.aten.convolution") %[[T2]], %[[T3]], %[[T4]] {
+// CHECK-SAME: dilation = [1 : index, 1 : index],
+// CHECK-SAME: groups = 1 : i64,
+// CHECK-SAME: output_padding = [1 : index, 1 : index],
+// CHECK-SAME: padding = [1 : index, 1 : index],
+// CHECK-SAME: stride = [2 : index, 2 : index],
+// CHECK-SAME: torch_operand_names = ["input", "weight", "bias"],
+// CHECK-SAME: transposed = true}
+// CHECK-SAME: tensor<1x64x1x100xf32>, tensor<64x64x3x3xf32>, tensor<64xf32> -> tensor<1x64x2x200xf32>
+// CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[CUSTOM]] : tensor<1x64x2x200xf32> -> !torch.vtensor<[1,64,2,200],f32>
+// CHECK: return %[[RES]] : !torch.vtensor<[1,64,2,200],f32>
+func.func @torch.aten.transposed_convolution(%input: !torch.vtensor<[1,64,1,100],f32>) -> !torch.vtensor<[1,64,2,200],f32> {
%true = torch.constant.bool true
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
@@ -78,7 +87,6 @@ func.func @torcn.aten.transposed_convolution(%input: !torch.vtensor<[1,64,1,100]
%stride = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
%int1x1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%output = torch.aten.convolution %input, %weight, %bias, %stride, %int1x1, %int1x1, %true, %int1x1, %int1 : !torch.vtensor<[1,64,1,100],f32>, !torch.vtensor<[64,64,3,3],f32>, !torch.vtensor<[64],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,64,2,200],f32>

return %output : !torch.vtensor<[1,64,2,200],f32>
}

@@ -95,6 +103,5 @@ func.func @torch.aten.regular_convolution() -> !torch.vtensor<[1,32,16,1600],f32
%int1x1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%none = torch.constant.none
%output = torch.aten.convolution %input, %weights, %none, %int1x1, %int1x1, %int1x1, %false, %int0x0, %int1 : !torch.vtensor<[1,9,16,1600],f32>, !torch.vtensor<[32,9,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,32,16,1600],f32>

return %output : !torch.vtensor<[1,32,16,1600],f32>
}
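
As a usage note, the RUN line at the top of this test corresponds to the following direct invocation (path illustrative; lit normally substitutes %s with the test file and runs this automatically):

tcp-opt <test/Conversion/TorchToTcp/tcp_custom_ops.mlir -convert-torch-to-tcp-custom-op -canonicalize -split-input-file | FileCheck test/Conversion/TorchToTcp/tcp_custom_ops.mlir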
