fix: int/int=float division #1957

Merged
merged 1 commit on May 29, 2023
64 changes: 30 additions & 34 deletions core/conversion/converters/impl/element_wise.cpp
@@ -26,6 +26,31 @@ nvinfer1::ITensor* clamp_util(
return clamp_layer_out;
}

void cast_int_int_div_tensors(
ConversionCtx* ctx,
const torch::jit::Node* n,
nvinfer1::ITensor*& a,
nvinfer1::ITensor*& b) {
// Torch automatically produces a float for int/int division
if (a->getType() == nvinfer1::DataType::kINT32 && b->getType() == nvinfer1::DataType::kINT32) {
a = castITensor(ctx, a, nvinfer1::DataType::kFLOAT, util::node_info(n) + "_a_cast");
b = castITensor(ctx, b, nvinfer1::DataType::kFLOAT, util::node_info(n) + "_b_cast");
}
}

bool element_wise_divide_implementation(
ConversionCtx* ctx,
const torch::jit::Node* n,
nvinfer1::ITensor* a,
nvinfer1::ITensor* b) {
cast_int_int_div_tensors(ctx, n, a, b);
auto element_wise = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, a, b, util::node_info(n));
TORCHTRT_CHECK(element_wise, "Unable to create element_wise layer from node: " << *n);
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], element_wise->getOutput(0));
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
return true;
}
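
Note on the helpers above: as the inline comment says, PyTorch's true division promotes two integer tensors to a floating-point result, so the converter casts kINT32 inputs to kFLOAT before the TensorRT kDIV layer to match eager-mode semantics. A minimal standalone libtorch sketch of that behavior (illustrative only, not part of this patch):

// Standalone example: true division of two int32 tensors yields a float tensor.
#include <torch/torch.h>
#include <iostream>

int main() {
  auto a = torch::tensor({7, 3}, torch::kInt32);
  auto b = torch::tensor({2, 2}, torch::kInt32);
  auto q = a / b;  // aten::div.Tensor -> kFloat values {3.5, 1.5}
  std::cout << q << "\ndtype: " << q.dtype() << std::endl;  // dtype: Float
  return 0;
}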

auto element_wise_registrations TORCHTRT_UNUSED =
RegisterNodeConversionPatterns()
.pattern(
@@ -296,18 +321,9 @@ auto element_wise_registrations TORCHTRT_UNUSED =
.pattern(
{"aten::div.Tensor(Tensor self, Tensor other) -> Tensor",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
// Should implement self / other
auto self = args[0].ITensorOrFreeze(ctx);
auto other = args[1].ITensorOrFreeze(ctx);
auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));

TORCHTRT_CHECK(div, "Unable to create div layer from node: " << *n);

div->setName(util::node_info(n).c_str());
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], div->getOutput(0));

LOG_DEBUG("Output tensor shape: " << out->getDimensions());
return true;
return element_wise_divide_implementation(ctx, n, self, other);
}})
.pattern(
{"aten::div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> (Tensor)",
@@ -349,6 +365,7 @@ auto element_wise_registrations TORCHTRT_UNUSED =
div = add_elementwise(
ctx, nvinfer1::ElementWiseOperation::kPROD, floor, sign->getOutput(0), util::node_info(n));
} else {
cast_int_int_div_tensors(ctx, n, self, other);
div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));
}

@@ -365,42 +382,21 @@ auto element_wise_registrations TORCHTRT_UNUSED =
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto self = args[0].ITensorOrFreeze(ctx);
auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar());
auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));
TORCHTRT_CHECK(div, "Unable to create div layer from node: " << *n);

div->setName(util::node_info(n).c_str());
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], div->getOutput(0));
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
return true;
return element_wise_divide_implementation(ctx, n, self, other);
}})
.pattern(
{"aten::div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
// TODO: Remove with functionalization
auto self = args[0].ITensorOrFreeze(ctx);
auto other = args[1].ITensorOrFreeze(ctx);
auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));

TORCHTRT_CHECK(div, "Unable to create div layer from node: " << *n);

div->setName(util::node_info(n).c_str());
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], div->getOutput(0));

LOG_DEBUG("Output tensor shape: " << out->getDimensions());
return true;
return element_wise_divide_implementation(ctx, n, self, other);
}})
.pattern(
{"aten::div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto self = args[0].ITensorOrFreeze(ctx);
auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar());
auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));
TORCHTRT_CHECK(div, "Unable to create div layer from node: " << *n);

div->setName(util::node_info(n).c_str());
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], div->getOutput(0));
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
return true;
return element_wise_divide_implementation(ctx, n, self, other);
}})
.pattern(
{"aten::square(Tensor self) -> Tensor",
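For context, castITensor (called by cast_int_int_div_tensors above) is an existing converter utility whose body is not shown in this diff. In the TensorRT C++ API, a cast of this kind is conventionally expressed as an identity layer with a forced output type, roughly along these lines (a hedged sketch with a hypothetical helper name, not the repo's actual implementation):

#include <string>
#include "NvInfer.h"

// Sketch: convert an INT32 network tensor to FP32 via an identity layer.
nvinfer1::ITensor* cast_to_float(
    nvinfer1::INetworkDefinition* network,
    nvinfer1::ITensor* in,
    const std::string& name) {
  auto* identity = network->addIdentity(*in);               // identity layer acts as the cast
  identity->setOutputType(0, nvinfer1::DataType::kFLOAT);   // force the output dtype to FP32
  identity->setName(name.c_str());
  return identity->getOutput(0);
}
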
15 changes: 15 additions & 0 deletions tests/core/conversion/converters/test_div.cpp
@@ -18,6 +18,7 @@ TEST(Converters, ATenDivConvertsCorrectly) {
pointwise_test_helper(graph, false, false, {4}, {3, 4});
pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kInt);
}

TEST(Converters, ATenDivWithScalarConvertsCorrectly) {
@@ -29,6 +30,16 @@ TEST(Converters, ATenDivWithScalarConvertsCorrectly) {
pointwise_test_helper(graph, true);
}

TEST(Converters, ATenDivWithScalarIntConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor):
%scalar : int = prim::Constant[value=2]()
%1 : Tensor = aten::div(%0, %scalar)
return (%1))IR";
pointwise_test_helper(graph, true);
pointwise_test_helper(graph, true, false, {5}, {1}, false, at::kInt);
}

TEST(Converters, ATenDivRoundingFloorConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor, %1 : Tensor):
@@ -42,6 +53,7 @@ TEST(Converters, ATenDivRoundingFloorConvertsCorrectly) {
pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3}, true);
pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kFloat, at::kInt);
pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kFloat);
pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kInt);
}

TEST(Converters, ATenDivRoundingTruncConvertsCorrectly) {
@@ -57,6 +69,7 @@ TEST(Converters, ATenDivRoundingTruncConvertsCorrectly) {
pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3}, true);
pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kFloat, at::kInt);
pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kFloat);
pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kInt);
}

TEST(Converters, ATenDivRoundingNoneConvertsCorrectly) {
@@ -70,6 +83,7 @@ TEST(Converters, ATenDivRoundingNoneConvertsCorrectly) {
pointwise_test_helper(graph, false, false, {4}, {3, 4}, true);
pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3}, true);
pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3}, true);
pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kInt);
}

TEST(Converters, ATenDivRoundingTruncWithIntsConvertsCorrectly) {
@@ -107,6 +121,7 @@ TEST(Converters, ATenFloorDivideConvertsCorrectly) {
pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kFloat, at::kInt);
pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kFloat);
pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kInt);
}

TEST(Converters, ATenFloorDivideWithScalarConvertsCorrectly) {