Skip to content

Commit

Permalink
fix: Add lowering pass to remove output repacking in `convert_method_…
Browse files Browse the repository at this point in the history
…to_trt_engine` calls (#1945)
  • Loading branch information
gs-olive authored and narendasan committed Jun 3, 2023
1 parent 8b6686a commit e950f1c
Show file tree
Hide file tree
Showing 10 changed files with 145 additions and 6 deletions.
3 changes: 3 additions & 0 deletions core/lowering/lowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,9 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::I
passes::RemoveSingleUse0DTensors(g);
passes::RemoveUnnecessaryCasts(g);
passes::ReplaceAtenInt(g);
if (lower_info.converting_to_trt_engine) {
passes::RemoveCollectionCast(g);
}
passes::UnpackAndCastMaskedFill(g, lower_info.getGPUDeviceString());
passes::UnpackAndCastNumToTensor(g, lower_info.getGPUDeviceString());
passes::UnpackAndCastFull(g, lower_info.getGPUDeviceString());
Expand Down
4 changes: 4 additions & 0 deletions core/lowering/lowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ struct LowerInfo {
// Since these QDQ nodes will be identical as they share same input, one of them is eliminated due to CSE lowering
// pass. Disable this in order to not disturb TensorRT's QAT optimizations.
bool disable_cse = false;

// Whether the originating caller is `convert_method_to_trt_engine` (true) or `compile` (false)
bool converting_to_trt_engine = false;

ir::Device target_device;
std::vector<std::string> forced_fallback_modules;
friend std::ostream& operator<<(std::ostream& os, const LowerInfo& l);
Expand Down
1 change: 1 addition & 0 deletions core/lowering/passes/passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ void RemoveNOPs(std::shared_ptr<torch::jit::Graph> graph);
void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g);
void RemoveUnnecessaryCasts(std::shared_ptr<torch::jit::Graph>& graph);
void ReplaceAtenInt(std::shared_ptr<torch::jit::Graph>& g);
void RemoveCollectionCast(std::shared_ptr<torch::jit::Graph>& g);
void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);
Expand Down
39 changes: 39 additions & 0 deletions core/lowering/passes/remove_unnecessary_casts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,45 @@ void ReplaceAtenInt(std::shared_ptr<torch::jit::Graph>& g) {
LOG_GRAPH("Post removing aten.Int.Tensor operations: " << *g);
}

void RemoveCollectionCast(std::shared_ptr<torch::jit::Graph>& g) {
// Removes unnecessary collection-casting of graph outputs
// Only to be used if the overall output is intended to be a TRT Engine
// Will cause errors if used directly as a TorchScript graph

// Validate the output is a single value with type Tuple or List
if (!(g->outputs().size() == 1 &&
(g->outputs()[0]->node()->kind() == torch::jit::prim::TupleConstruct ||
g->outputs()[0]->node()->kind() == torch::jit::prim::ListConstruct))) {
return;
}

// Ensure all inputs to the Tuple/List Construct operator are regular Tensors
// (nested structures cannot be preserved in TensorRT)
auto all_tensors = true;
auto collection_inputs = g->outputs()[0]->node()->inputs();

for (size_t i = 0; i < collection_inputs.size(); ++i) {
all_tensors &= collection_inputs[i]->type()->isSubtypeOf(c10::TensorType::get());
}

if (!all_tensors) {
return;
}

// For each input to the collection packing operator, add its value directly
// as an output of the graph
for (size_t i = 0; i < collection_inputs.size(); ++i) {
g->registerOutput(collection_inputs[i]);
}

// Remove the original output value of the graph (the collection object)
g->eraseOutput(0);

// Clean up remnant collection node in graph
torch::jit::EliminateDeadCode(g);
LOG_GRAPH("Post removing collection casting operations: " << *g);
}

} // namespace passes
} // namespace lowering
} // namespace core
Expand Down
4 changes: 3 additions & 1 deletion cpp/src/compile_spec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,11 @@ torchtrt::core::CompileSpec init_compile_spec(CompileSpec& external) {
}
}

torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {
torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external, bool converting_to_trt_engine) {
torchtrt::core::CompileSpec internal = init_compile_spec(external);

internal.lower_info.converting_to_trt_engine = converting_to_trt_engine;

for (auto p : external.enabled_precisions) {
internal.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(p));
}
Expand Down
5 changes: 3 additions & 2 deletions cpp/src/torch_tensorrt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace torch_tensorrt {
torch_tensorrt::core::runtime::RTDevice to_internal_rt_device(Device device);
namespace torchscript {
// Defined in compile_spec.cpp
torch_tensorrt::core::CompileSpec to_internal_compile_spec(CompileSpec external);
torch_tensorrt::core::CompileSpec to_internal_compile_spec(CompileSpec external, bool converting_to_trt_engine = false);

bool check_method_operator_support(const torch::jit::script::Module& module, std::string method_name) {
return torch_tensorrt::core::CheckMethodOperatorSupport(module, method_name);
Expand All @@ -23,7 +23,8 @@ std::string convert_method_to_trt_engine(
LOG_DEBUG(get_build_info());
// Want to export a much simpler (non TRT header dependent) API so doing the
// type conversion here
return torch_tensorrt::core::ConvertGraphToTRTEngine(module, method_name, to_internal_compile_spec(info));
return torch_tensorrt::core::ConvertGraphToTRTEngine(
module, method_name, to_internal_compile_spec(info, /*bool converting_to_trt_engine=*/true));
}

torch::jit::script::Module compile(const torch::jit::script::Module& module, CompileSpec info) {
Expand Down
4 changes: 3 additions & 1 deletion py/torch_tensorrt/csrc/tensorrt_classes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,9 +326,11 @@ core::CompileSpec init_compile_spec(CompileSpec external) {
}
}

core::CompileSpec CompileSpec::toInternalCompileSpec() {
core::CompileSpec CompileSpec::toInternalCompileSpec(bool converting_to_trt_engine) {
core::CompileSpec info = init_compile_spec(*this);

info.lower_info.converting_to_trt_engine = converting_to_trt_engine;

for (auto p : enabled_precisions) {
info.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(p));
}
Expand Down
2 changes: 1 addition & 1 deletion py/torch_tensorrt/csrc/tensorrt_classes.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ std::string to_str(EngineCapability value);
nvinfer1::EngineCapability toTRTEngineCapability(EngineCapability value);

struct CompileSpec : torch::CustomClassHolder {
core::CompileSpec toInternalCompileSpec();
core::CompileSpec toInternalCompileSpec(bool converting_to_trt_engine = false);
std::string stringify();
void appendInput(const c10::intrusive_ptr<Input>& ir) {
inputs.push_back(*ir);
Expand Down
3 changes: 2 additions & 1 deletion py/torch_tensorrt/csrc/torch_tensorrt_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec& info

py::bytes ConvertGraphToTRTEngine(const torch::jit::Module& mod, const std::string& method_name, CompileSpec& info) {
py::gil_scoped_acquire gil;
auto trt_engine = core::ConvertGraphToTRTEngine(mod, method_name, info.toInternalCompileSpec());
auto trt_engine = core::ConvertGraphToTRTEngine(
mod, method_name, info.toInternalCompileSpec(/*bool converting_to_trt_engine=*/true));
return py::bytes(trt_engine);
}

Expand Down
86 changes: 86 additions & 0 deletions tests/core/lowering/test_remove_unnecessary_casts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -589,3 +589,89 @@ TEST(LoweringPasses, RemoveAtenIntConstTensorValuesAgree) {
// Validate identical graphs after pooling constants and canonicalizing
ASSERT_TRUE((tg->toString() == sg->toString()));
}

TEST(LoweringPasses, RemoveCollectionCastTuple) {
// Ensure the lowering pass transforms the first graph into the second
std::string source_graph = R"IR(
graph(%x.1 : Tensor):
%3 : int = prim::Constant[value=1]()
%2 : int = prim::Constant[value=2]()
%a.1 : Tensor = aten::mul(%x.1, %2)
%b.1 : Tensor = aten::add(%a.1, %2, %3)
%c.1 : Tensor = aten::relu(%b.1)
%d.1 : Tensor = aten::sqrt(%c.1)
%8 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%c.1, %d.1, %b.1)
return (%8))IR";

std::string target_graph = R"IR(
graph(%x.1 : Tensor):
%3 : int = prim::Constant[value=1]()
%2 : int = prim::Constant[value=2]()
%a.1 : Tensor = aten::mul(%x.1, %2)
%b.1 : Tensor = aten::add(%a.1, %2, %3)
%c.1 : Tensor = aten::relu(%b.1)
%d.1 : Tensor = aten::sqrt(%c.1)
return (%c.1, %d.1, %b.1))IR";

// Ensure the lowering pass transforms the first graph into the second
torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
auto sg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(source_graph, sg.get());

torch_tensorrt::core::lowering::passes::RemoveCollectionCast(sg);
torch::jit::ConstantPooling(sg);
sg = torch::jit::Canonicalize(sg, false);

auto tg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(target_graph, tg.get());

torch::jit::ConstantPooling(tg);
tg = torch::jit::Canonicalize(tg, false);

// Validate identical graphs after pooling constants and canonicalizing
ASSERT_TRUE((tg->toString() == sg->toString()));
}

TEST(LoweringPasses, RemoveCollectionCastList) {
// Ensure the lowering pass transforms the first graph into the second
std::string source_graph = R"IR(
graph(%x.1 : Tensor):
%3 : int = prim::Constant[value=1]()
%2 : int = prim::Constant[value=2]()
%a.1 : Tensor = aten::mul(%x.1, %2)
%b.1 : Tensor = aten::add(%a.1, %2, %3)
%c.1 : Tensor = aten::relu(%b.1)
%d.1 : Tensor = aten::sqrt(%c.1)
%8 : (Tensor, Tensor, Tensor) = prim::ListConstruct(%b.1, %c.1, %d.1)
return (%8))IR";

std::string target_graph = R"IR(
graph(%x.1 : Tensor):
%3 : int = prim::Constant[value=1]()
%2 : int = prim::Constant[value=2]()
%a.1 : Tensor = aten::mul(%x.1, %2)
%b.1 : Tensor = aten::add(%a.1, %2, %3)
%c.1 : Tensor = aten::relu(%b.1)
%d.1 : Tensor = aten::sqrt(%c.1)
return (%b.1, %c.1, %d.1))IR";

// Ensure the lowering pass transforms the first graph into the second
torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
auto sg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(source_graph, sg.get());

torch_tensorrt::core::lowering::passes::RemoveCollectionCast(sg);
torch::jit::ConstantPooling(sg);
sg = torch::jit::Canonicalize(sg, false);

auto tg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(target_graph, tg.get());

torch::jit::ConstantPooling(tg);
tg = torch::jit::Canonicalize(tg, false);

// Validate identical graphs after pooling constants and canonicalizing
ASSERT_TRUE((tg->toString() == sg->toString()));
}

0 comments on commit e950f1c

Please sign in to comment.