Commit 46950bb

fix: support shape inference for add_, support non-tensor arguments for segmented graphs

Signed-off-by: Bo Wang <wangbo1995ee@163.com>
bowang007 committed Mar 17, 2021
1 parent 8b7919f commit 46950bb
Showing 7 changed files with 103 additions and 41 deletions.
29 changes: 17 additions & 12 deletions core/compiler.cpp
@@ -111,17 +111,18 @@ void AddEngineToGraph(
g->block()->appendNode(unpack_node);

// If there are multiple output tensors from TensorRT we wrap them in a tuple
// to return
if (unpack_node->outputs().size() > 1) {
// to return, convert to a tuple only when we have a single segmented graph
if (!engine_id && unpack_node->outputs().size() > 1) {
// Creates prim::TupleConstruct(<output tensors>) using outputs of the
// unpack node
auto return_tuple_node = g->createTuple(unpack_node->outputs());
g->block()->appendNode(return_tuple_node);
// Set the output as the produced tuple
g->registerOutput(return_tuple_node->outputs()[0]);
} else {
// Set the output as the sole output tensor
g->registerOutput(unpack_node->outputs()[0]);
for (int i = 0; i < unpack_node->outputs().size(); ++i) {
g->registerOutput(unpack_node->outputs()[i]);
}
}

LOG_DEBUG(*g << "(AddEngineToGraph)\n");
@@ -159,32 +160,35 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::

void AddSegmentedBlockToGraph(std::shared_ptr<torch::jit::Graph>& g, partitioning::SegmentedBlock &seg,
std::unordered_map<torch::jit::Value*, torch::jit::Value*> &old_to_new_g) {
//old_to_new_g contains: original_graph value => new graph value, mini_graph value -> new graph value, new graph value -> mini_graph value
//old_to_new_g contains: original global graph value => new global graph value,
//mini_to_new_g: mini graph value -> new graph value
std::unordered_map<torch::jit::Value*, torch::jit::Value*> mini_to_new_g;
size_t input_idx = 0;
if (seg.target() == partitioning::SegmentedBlock::kTensorRT && g->inputs().size() > 0) {
if (g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
auto self = g->insertInput(0, "self_1");
self->setType(seg.inputs()[0]->type());
}
old_to_new_g[seg.inputs()[input_idx++]] = g->inputs()[0];
mini_to_new_g[seg.inputs()[input_idx++]] = g->inputs()[0];
}


for (auto &raw_input : seg.raw_inputs()) {
if (old_to_new_g.count(raw_input)) {
old_to_new_g[seg.inputs()[input_idx++]] = old_to_new_g[raw_input];
mini_to_new_g[seg.inputs()[input_idx++]] = old_to_new_g[raw_input];
}
}

for (const auto n : seg.nodes()) {
partitioning::cloneNode(n, g, old_to_new_g);
partitioning::cloneNode(n, g, mini_to_new_g);
}

// original graph value => new global graph value
for (size_t i = 0; i < seg.raw_outputs().size(); ++i) {
old_to_new_g[seg.raw_outputs()[i]] = old_to_new_g[seg.outputs()[i]];
old_to_new_g[seg.raw_outputs()[i]] = mini_to_new_g[seg.outputs()[i]];
}

// LOG_INFO(*g << "(AddSegmentedBlockToGraph)\n");
LOG_INFO(*g << "(AddSegmentedBlockToGraph)\n");
return;
}

@@ -199,21 +203,21 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
if (method.name().rfind("_", 0)) {
auto new_g = std::make_shared<torch::jit::Graph>();
auto graph_and_parameters = lowering::Lower(mod, method.name());
LOG_INFO(*(method.graph()) << "Original graph\n");

auto g = graph_and_parameters.first;
auto params = graph_and_parameters.second;
auto named_params = conversion::get_named_params(g->inputs(), params);
auto convert_cfg = std::move(cfg.convert_info);
LOG_INFO(*g << "(CompileGraph)\n");


// segment the graph and convert segmented TensorRT block
auto segmented_blocks = partitioning::segment_graph(g, convert_cfg.input_ranges, convert_cfg.engine_settings.torch_fallback);
if (segmented_blocks.size() == 1 && segmented_blocks[0].target() == partitioning::SegmentedBlock::kTorch) {
return mod;
}

int trt_engine_id = 0;
int trt_engine_id = 1;
std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
for (auto &seg_block : segmented_blocks) {
if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
@@ -225,6 +229,7 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params);
auto temp_g = std::make_shared<torch::jit::Graph>();
AddEngineToGraph(new_mod, temp_g, engine, trt_engine_id++);

seg_block.update_graph(temp_g);
AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
} else {
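The stitching above hinges on one technique: clone nodes from a segment's mini graph into the new global graph while translating every Value through a map. A minimal standalone sketch of that idea (the toy IR and names are illustrative, not code from this commit):

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <iostream>
#include <memory>
#include <unordered_map>

int main() {
  // a stand-in for one segmented block's mini graph
  auto mini = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(R"IR(
    graph(%x : Tensor):
      %y : Tensor = aten::relu(%x)
      return (%y))IR", mini.get());

  auto global = std::make_shared<torch::jit::Graph>();
  std::unordered_map<torch::jit::Value*, torch::jit::Value*> mini_to_new;

  // map the mini graph's inputs onto fresh inputs of the global graph
  for (auto in : mini->inputs()) {
    auto new_in = global->addInput();
    new_in->copyMetadata(in);
    mini_to_new[in] = new_in;
  }

  // clone each node, translating its inputs through the map and recording
  // its outputs so downstream nodes (and registerOutput) can resolve them
  for (auto node : mini->nodes()) {
    auto env = [&](torch::jit::Value* v) { return mini_to_new.at(v); };
    auto cloned = global->block()->appendNode(global->createClone(node, env));
    for (size_t i = 0; i < node->outputs().size(); ++i) {
      mini_to_new[node->outputs()[i]] = cloned->outputs()[i];
    }
  }
  for (auto out : mini->outputs()) {
    global->registerOutput(mini_to_new.at(out));
  }
  std::cout << *global; // the stitched global graph
}

cloneNode in core/partitioning does essentially this, and AddSegmentedBlockToGraph keeps two such maps (old_to_new_g across the whole graph, mini_to_new_g per segment) so values flow correctly from one segment to the next.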
3 changes: 2 additions & 1 deletion core/lowering/passes/BUILD
@@ -26,7 +26,8 @@ cc_library(
"unpack_batch_norm.cpp",
"unpack_log_softmax.cpp",
"op_aliasing.cpp",
"silu_to_sigmoid_multiplication.cpp"
"silu_to_sigmoid_multiplication.cpp",
"remove_inplace_add.cpp"
],
deps = [
"//core/util:prelude",
1 change: 1 addition & 0 deletions core/lowering/passes/passes.h
@@ -21,6 +21,7 @@ void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);
void AliasOperators(std::shared_ptr<torch::jit::Graph>& graph);
void SiluToSigmoidMultipication(std::shared_ptr<torch::jit::Graph>& graph);
void RemoveInplaceAdd(std::shared_ptr<torch::jit::Graph>& graph);

} // namespace passes
} // namespace lowering
30 changes: 30 additions & 0 deletions core/lowering/passes/remove_inplace_add.cpp
@@ -0,0 +1,30 @@
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

#include "core/util/prelude.h"

namespace trtorch {
namespace core {
namespace lowering {
namespace passes {

void RemoveInplaceAdd(std::shared_ptr<torch::jit::Graph>& graph) {
std::string inplace_add_pattern = R"IR(
graph(%self, %other, %1):
%out = aten::add_(%self, %other, %1)
return (%out))IR";
std::string normal_add_pattern = R"IR(
graph(%self, %other, %1):
%out = aten::add(%self, %other, %1)
return (%out))IR";

torch::jit::SubgraphRewriter remove_inplace_add;
remove_inplace_add.RegisterRewritePattern(inplace_add_pattern, normal_add_pattern);
remove_inplace_add.runOnGraph(graph);

LOG_GRAPH("Post remove inplace add: " << *graph);
}

} // namespace passes
} // namespace lowering
} // namespace core
} // namespace trtorch
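
A usage sketch for the pass above (a hypothetical standalone driver, not part of the commit): parse a graph containing aten::add_, apply the same pattern pair the pass registers, and the in-place op is rewritten to its functional form, which conversion and shape inference can handle.

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <iostream>
#include <memory>

int main() {
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(R"IR(
    graph(%self : Tensor, %other : Tensor):
      %1 : int = prim::Constant[value=1]()
      %out : Tensor = aten::add_(%self, %other, %1)
      return (%out))IR", g.get());

  // the same rewrite RemoveInplaceAdd registers: aten::add_ -> aten::add
  torch::jit::SubgraphRewriter remove_inplace_add;
  remove_inplace_add.RegisterRewritePattern(
      R"IR(
      graph(%self, %other, %1):
        %out = aten::add_(%self, %other, %1)
        return (%out))IR",
      R"IR(
      graph(%self, %other, %1):
        %out = aten::add(%self, %other, %1)
        return (%out))IR");
  remove_inplace_add.runOnGraph(g);

  std::cout << *g; // %out is now produced by aten::add
}

Inside TRTorch the same effect comes from calling lowering::passes::RemoveInplaceAdd(copy_g), which registerSegmentInOutIValues below applies to each segment's copied graph before running it.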
3 changes: 2 additions & 1 deletion core/partitioning/BUILD
@@ -17,7 +17,8 @@ cc_library(
],
deps = [
"//core/conversion",
"//core/util:prelude"
"//core/util:prelude",
"//core/lowering"
] + select({
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
"//conditions:default": ["@libtorch//:libtorch"],
69 changes: 45 additions & 24 deletions core/partitioning/partitioning.cpp
@@ -2,6 +2,8 @@
#include "core/util/prelude.h"
#include "torch/csrc/jit/api/module.h"
#include "core/util/prelude.h"
#include "core/lowering/passes/passes.h"



namespace trtorch {
@@ -20,9 +22,9 @@ torch::jit::Value* getOrAddInputForValue(torch::jit::Value* old_value, std::shar
}
auto new_value = graph->block()->addInput();
old_to_new[old_value] = new_value;
new_value->copyMetadata(old_value);
// mapping from new graph input Values to original graph values
old_to_new[new_value] = old_value;
new_value->copyMetadata(old_value);
return new_value;
} else {
return old_to_new[old_value];
@@ -40,7 +42,6 @@ torch::jit::Node* cloneNode(torch::jit::Node* node, std::shared_ptr<torch::jit::
auto no = new_node->outputs()[i];
old_to_new[oo] = no;
}

return new_node;
}

@@ -58,10 +59,13 @@ c10::FunctionSchema getFunctionSchema(std::string method_name, std::shared_ptr<t
return c10::FunctionSchema(method_name, method_name, args, returns);
}

void registerSegmentInOutShape(SegmentedBlock &seg_block, std::unordered_map<torch::jit::Value*, nvinfer1::Dims> &input_shape_map) {
void registerSegmentInOutIValues(SegmentedBlock &seg_block, std::unordered_map<torch::jit::Value*, torch::jit::IValue> &ivalues_maps) {
// create a module to run the graph
auto g = seg_block.g();
auto copy_g = g->copy();
lowering::passes::RemoveInplaceAdd(copy_g);

// create tuple for multiple outputs
if (seg_block.raw_outputs().size() > 1) {
auto new_output_node = copy_g->appendNode(copy_g->createTuple(copy_g->outputs()));
for (int idx = copy_g->outputs().size() - 1; idx >= 0; --idx) {
@@ -84,46 +88,60 @@

// set inputs ivalues
for (auto &input : seg_block.raw_inputs()) {
std::vector<int64_t> shape;
nvinfer1::Dims cur_shape = input_shape_map[input];
shape.insert(shape.begin(), std::begin(cur_shape.d), std::begin(cur_shape.d) + cur_shape.nbDims);
auto in = at::randint(5, shape, {at::kCUDA});
jit_inputs_ivalues.push_back(in.clone());
if (!ivalues_maps.count(input)) {
std::cerr << "could not find graph input ivalues\n";
}
if (input->type()->isSubtypeOf(torch::jit::TensorType::get())) {
jit_inputs_ivalues.push_back(ivalues_maps[input].toTensor());
} else if (input->type()->isSubtypeOf(torch::jit::IntType::get())) {
jit_inputs_ivalues.push_back(ivalues_maps[input].toInt());
}
}

std::vector<at::Tensor> jit_results;
std::vector<torch::jit::IValue> jit_results;
torch::jit::IValue jit_results_ivalues = cur_mod.forward(jit_inputs_ivalues);
if (jit_results_ivalues.isTensor()) {
jit_results.push_back(jit_results_ivalues.toTensor());
} else {
if (jit_results_ivalues.isTuple()) {
auto results = jit_results_ivalues.toTuple()->elements();
for (auto r : results) {
jit_results.push_back(r.toTensor());
jit_results.push_back(r);
}
} else {
jit_results.push_back(jit_results_ivalues);
}

size_t idx = 0;
for (auto &output : seg_block.raw_outputs()) {
input_shape_map[output] = util::toDims(jit_results[idx++].sizes());
ivalues_maps[output] = jit_results[idx++];
}

// set input shape for each segmented block so we will use it in the conversion process
std::vector<nvinfer1::Dims> input_shape;
for (auto &i : seg_block.raw_inputs()) {
input_shape.push_back(input_shape_map[i]);
if (ivalues_maps[i].isTensor()) {
input_shape.push_back(util::toDims(ivalues_maps[i].toTensor().sizes()));
}
}

seg_block.register_inshape(input_shape);
}

std::vector<nvinfer1::Dims> extractNvinfer1Dims(std::vector<conversion::InputRange>& input_ranges) {
std::vector<nvinfer1::Dims> res;

std::vector<torch::jit::IValue> generateRandomInputs(std::vector<conversion::InputRange>& input_ranges) {
std::vector<torch::jit::IValue> random_inputs;
for (auto &input_range : input_ranges) {
res.push_back(input_range.input_shape);
auto cur_shape = input_range.input_shape;
std::vector<int64_t> shape;
shape.insert(shape.begin(), std::begin(cur_shape.d), std::begin(cur_shape.d) + cur_shape.nbDims);
auto in = at::randint(5, shape, {at::kCUDA});
random_inputs.push_back(in.clone());
printf("is tensor: %d\n", random_inputs.back().isTensor());
}
return res;
return random_inputs;
}


void registerSegmentsInputsOutputs(std::vector<SegmentedBlock> &segmented_blocks, std::shared_ptr<torch::jit::Graph> g) {
// find the corresponding raw values in original global graph for this segmented block's inputs/outputs
std::set<torch::jit::Value*> input_values;
for (auto &seg_block : segmented_blocks) {
seg_block.registerInputs();
@@ -176,6 +194,7 @@ std::vector<SegmentedBlock> segment_graph(std::shared_ptr<torch::jit::Graph> g,

for (const auto n : nodes) {
if (n->kind() == torch::jit::prim::Constant) continue;

std::string node_string(n->kind().toQualString());

if (conversion::OpSupported(n) && !forced_fallback_operators.count(node_string)) {
@@ -186,19 +205,21 @@ std::vector<SegmentedBlock> segment_graph(std::shared_ptr<torch::jit::Graph> g,
}
}
merge_nodes(pytorch_nodes, tensorrt_nodes, segmented_blocks, min_block_size);
if (!pytorch_nodes.empty()) segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
if (!pytorch_nodes.empty()) {
segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
}

registerSegmentsInputsOutputs(segmented_blocks, g);

std::vector<nvinfer1::Dims> graph_inputs_shape = extractNvinfer1Dims(input_ranges);
std::unordered_map<torch::jit::Value*, nvinfer1::Dims> input_shape_map;
std::unordered_map<torch::jit::Value*, torch::jit::IValue> ivalues_maps;

std::vector<torch::jit::IValue> random_inputs = generateRandomInputs(input_ranges);
for (size_t i = 0; i < g->inputs().size(); ++i) {
input_shape_map[g->inputs()[i]] = graph_inputs_shape[i];
ivalues_maps[g->inputs()[i]] = random_inputs[i];
}

for (auto &seg_block : segmented_blocks) {
registerSegmentInOutShape(seg_block, input_shape_map);
registerSegmentInOutIValues(seg_block, ivalues_maps);
}

return segmented_blocks;
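The shape-inference strategy above runs each segment on concrete random inputs and records the produced IValues, so non-tensor values (ints, for example) propagate between segments the same way tensors do. A standalone sketch of the execution-based idea (hypothetical: it uses a plain GraphExecutor on CPU instead of the script::Module wrapper and CUDA tensors the commit builds):

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <ATen/ATen.h>
#include <iostream>
#include <memory>
#include <vector>

int main() {
  // a stand-in for one segmented block
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(R"IR(
    graph(%x : Tensor, %y : Tensor):
      %1 : int = prim::Constant[value=1]()
      %out : Tensor = aten::add(%x, %y, %1)
      return (%out))IR", g.get());

  // random tensors of the chosen input shape, as generateRandomInputs does
  std::vector<c10::IValue> stack;
  stack.emplace_back(at::randint(5, {2, 3}, {at::kCPU}));
  stack.emplace_back(at::randint(5, {2, 3}, {at::kCPU}));

  torch::jit::GraphExecutor executor(g, "segment");
  executor.run(stack); // outputs replace the inputs on the stack

  // read the output shape back, as registerSegmentInOutIValues does when
  // filling the next segment's input map
  std::cout << stack.back().toTensor().sizes() << "\n"; // prints [2, 3]
}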
9 changes: 6 additions & 3 deletions core/partitioning/partitioning.h
@@ -10,6 +10,9 @@ namespace trtorch {
namespace core {
namespace partitioning {

torch::jit::Value* getOrAddInputForValue(torch::jit::Value* old_value, std::shared_ptr<torch::jit::Graph> &graph,
std::unordered_map<torch::jit::Value*, torch::jit::Value*> &old_to_new);

torch::jit::Node* cloneNode(torch::jit::Node* node, std::shared_ptr<torch::jit::Graph> &graph,
std::unordered_map<torch::jit::Value*, torch::jit::Value*> &old_to_new);

@@ -49,7 +52,6 @@ struct SegmentedBlock {

void registerOutput(torch::jit::Value* raw_input) {
outputs_.push_back(raw_input);

g_->registerOutput(old_to_new_[raw_input]);
}

@@ -97,15 +99,16 @@ struct SegmentedBlock {
return out_shape_;
}

const std::shared_ptr<torch::jit::Graph>& g() const {
std::shared_ptr<torch::jit::Graph>& g() {
return g_;
}


void update_graph(std::shared_ptr<torch::jit::Graph> new_g) {
g_ = new_g;
}

private:
// private:
SegmentedBlockTarget target_;
std::vector<nvinfer1::Dims> in_shape_;
std::vector<nvinfer1::Dims> out_shape_;
