feat(//py): New API to embed engine in new module
Also adds tests to confirm TRT Python API intercompatibility

Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
narendasan committed Apr 21, 2021
1 parent 3ec836e commit 88d07a9
Showing 10 changed files with 129 additions and 18 deletions.
6 changes: 3 additions & 3 deletions core/compiler.cpp
@@ -46,7 +46,7 @@ c10::FunctionSchema GenerateGraphSchema(
void AddEngineToGraph(
    torch::jit::script::Module mod,
    std::shared_ptr<torch::jit::Graph>& g,
-    std::string& serialized_engine) {
+    const std::string& serialized_engine) {
  auto engine_ptr = c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name(), serialized_engine);
  // Get required metadata about the engine out
  auto num_io = engine_ptr->num_io;
@@ -173,9 +173,9 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
  return new_mod;
}

-torch::jit::script::Module EmbedEngineInNewModule(std::string& engine) {
+torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine) {
  std::ostringstream engine_id;
-  engine_id << reinterpret_cast<int*>(&engine);
+  engine_id << reinterpret_cast<const int*>(&engine);
  torch::jit::script::Module new_mod("tensorrt_engine_mod_" + engine_id.str());
  auto new_g = std::make_shared<torch::jit::Graph>();
  AddEngineToGraph(new_mod, new_g, engine);
2 changes: 1 addition & 1 deletion core/compiler.h
@@ -19,7 +19,7 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::

torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, CompileSpec cfg);

-torch::jit::script::Module EmbedEngineInNewModule(std::string& engine);
+torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine);

void set_device(const int gpu_id);

9 changes: 5 additions & 4 deletions cpp/api/include/trtorch/trtorch.h
@@ -485,14 +485,15 @@ TRTORCH_API std::string ConvertGraphToTRTEngine(
 * @brief Take a previously created TensorRT engine and embed it
 * in a TorchScript module
 *
- * @param engine: std::string - Precompiled serialized TensorRT engine
+ * @param engine: std::string - Pre-built serialized TensorRT engine
 *
- * Takes a prebuilt serialized TensorRT engine and embeds it in a TorchScript
- * graph. Registers the engine as the forward method of the module
+ * Takes a pre-built serialized TensorRT engine and embeds it in a TorchScript
+ * module. Registers execution of the engine as the forward method of the module
 * Forward is defined as: forward(Tensor[]) -> Tensor[]
 *
 * @return: A new module targeting a TensorRT engine
 */
-TRTORCH_API torch::jit::Module EmbedEngineInNewModule(std::string& engine);
+TRTORCH_API torch::jit::Module EmbedEngineInNewModule(const std::string& engine);

/**
 * @brief Set gpu device id
2 changes: 1 addition & 1 deletion cpp/api/src/trtorch.cpp
@@ -31,7 +31,7 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module
  return core::CompileGraph(module, to_internal_compile_spec(info));
}

-torch::jit::Module EmbedEngineInNewModule(std::string& engine) {
+torch::jit::Module EmbedEngineInNewModule(const std::string& engine) {
  return core::EmbedEngineInNewModule(engine);
}

20 changes: 20 additions & 0 deletions py/trtorch/_compiler.py
@@ -124,6 +124,26 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule, method_name: st
    return trtorch._C.convert_graph_to_trt_engine(module._c, method_name, _parse_compile_spec(compile_spec))


+def embed_engine_in_new_module(serialized_engine: bytes) -> torch.jit.ScriptModule:
+    """Takes a pre-built serialized TensorRT engine and embeds it within a TorchScript module
+
+    Takes a pre-built serialized TensorRT engine (as bytes) and embeds it within a TorchScript module.
+    Registers the forward method to execute the TensorRT engine with the function signature:
+
+        forward(Tensor[]) -> Tensor[]
+
+    The module can be saved with the engine embedded using torch.jit.save and moved / loaded according to TRTorch portability rules
+
+    Args:
+        serialized_engine (bytes): Serialized TensorRT engine from either TRTorch or TensorRT APIs
+
+    Returns:
+        torch.jit.ScriptModule: New TorchScript module with engine embedded
+    """
+    cpp_mod = trtorch._C.embed_engine_in_new_module(serialized_engine)
+    return torch.jit._recursive.wrap_cpp_module(cpp_mod)


def check_method_op_support(module: torch.jit.ScriptModule, method_name: str) -> bool:
    """Checks to see if a method is fully supported by TRTorch
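
As the docstring notes, the wrapped module round-trips through torch.jit.save. A brief sketch, assuming serialized_engine holds bytes from convert_method_to_trt_engine (the file name is illustrative, and the loading machine must satisfy TRTorch's portability rules, e.g. a compatible GPU and TensorRT version):

    import torch
    import trtorch  # registers the TRT engine runtime classes needed at load time

    trt_mod = trtorch.embed_engine_in_new_module(serialized_engine)
    torch.jit.save(trt_mod, "trt_embedded.ts")

    # Later, possibly in a process that never saw the original model definition:
    reloaded = torch.jit.load("trt_embedded.ts")
    out = reloaded(torch.randn((1, 3, 224, 224)).to("cuda"))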
8 changes: 8 additions & 0 deletions py/trtorch/csrc/trtorch_py.cpp
@@ -119,6 +119,10 @@ bool CheckMethodOperatorSupport(const torch::jit::Module& module, const std::str
  return core::CheckMethodOperatorSupport(module, method_name);
}

+torch::jit::Module EmbedEngineInNewModule(const py::bytes& engine) {
+  return core::EmbedEngineInNewModule(engine);
+}

std::string get_build_info() {
  auto info = core::util::get_build_info();
  return info;
@@ -270,6 +274,10 @@ PYBIND11_MODULE(_C, m) {
"check_method_op_support",
&trtorch::pyapi::CheckMethodOperatorSupport,
"Takes a module and a method name and checks if the method graph contains purely convertable operators");
m.def(
"embed_engine_in_new_module",
&trtorch::pyapi::EmbedEngineInNewModule,
"Takes a serialized TensorRT engine and wraps it in the forward method of a new TorchScript module");
m.def("get_build_info", &get_build_info, "Returns build info about the compiler as a string");

m.def("_get_logging_prefix", &logging::get_logging_prefix, "Get the current prefix for the logging output");
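
The binding accepts py::bytes, which pybind11 converts to std::string without re-encoding, so serialized engines containing arbitrary binary data cross the language boundary intact. The public wrapper in _compiler.py adds only the C++-module wrapping step; a sketch of that equivalence, assuming serialized_engine holds valid engine bytes:

    import torch
    import trtorch

    # What trtorch.embed_engine_in_new_module does under the hood:
    cpp_mod = trtorch._C.embed_engine_in_new_module(serialized_engine)
    trt_mod = torch.jit._recursive.wrap_cpp_module(cpp_mod)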
2 changes: 1 addition & 1 deletion tests/modules/test_modules_as_engines.cpp
@@ -16,7 +16,7 @@ TEST_P(ModuleTests, ModuleAsEngineIsClose) {
  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-5));
}

-TEST_P(ModuleTests, ModuleToModuleIsClose) {
+TEST_P(ModuleTests, ModuleToEngineToModuleIsClose) {
  std::vector<at::Tensor> inputs;
  std::vector<torch::jit::IValue> inputs_ivalues;
  for (auto in_shape : input_shapes) {
19 changes: 14 additions & 5 deletions tests/py/BUILD
@@ -30,7 +30,7 @@ py_test(
    srcs = [
        "test_ptq_dataloader_calibrator.py",
        "model_test_case.py"
-    ]
+    ],
    deps = [
        requirement("torchvision")
    ]
@@ -43,7 +43,7 @@
    srcs = [
        "test_ptq_trt_calibrator.py",
        "model_test_case.py"
-    ]
+    ],
    deps = [
        requirement("torchvision")
    ]
@@ -56,8 +56,6 @@
        "test_multi_gpu.py",
        "model_test_case.py"
    ],
-    "//conditions:default" : []
-    }),
    deps = [
        requirement("torchvision")
    ]
@@ -74,12 +72,23 @@
    ]
)

+py_test(
+    name = "test_trt_intercompatability",
+    srcs = [
+        "test_trt_intercompatability.py",
+        "model_test_case.py"
+    ],
+    deps = [
+        requirement("torchvision")
+    ]
+)

py_test(
    name = "test_ptq_to_backend",
    srcs = [
        "test_ptq_to_backend.py",
        "model_test_case.py"
-    ]
+    ],
+    deps = [
+        requirement("torchvision")
+    ]
28 changes: 25 additions & 3 deletions tests/py/test_api.py
@@ -45,6 +45,27 @@ def test_compile_script(self):
        same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
        self.assertTrue(same < 2e-3)

+class TestPTtoTRTtoPT(ModelTestCase):
+    def setUp(self):
+        self.input = torch.randn((1, 3, 224, 224)).to("cuda")
+        self.ts_model = torch.jit.script(self.model)
+
+    def test_pt_to_trt_to_pt(self):
+        compile_spec = {
+            "input_shapes": [self.input.shape],
+            "device": {
+                "device_type": trtorch.DeviceType.GPU,
+                "gpu_id": 0,
+                "dla_core": 0,
+                "allow_gpu_fallback": False,
+                "disable_tf32": False
+            }
+        }
+
+        trt_engine = trtorch.convert_method_to_trt_engine(self.ts_model, "forward", compile_spec)
+        trt_mod = trtorch.embed_engine_in_new_module(trt_engine)
+        same = (trt_mod(self.input) - self.ts_model(self.input)).abs().max()
+        self.assertTrue(same < 2e-3)

class TestCheckMethodOpSupport(unittest.TestCase):

@@ -59,13 +80,13 @@ def test_check_support(self):

class TestLoggingAPIs(unittest.TestCase):

    def test_logging_prefix(self):
-        new_prefix = "TEST"
+        new_prefix = "Python API Test: "
        trtorch.logging.set_logging_prefix(new_prefix)
        logging_prefix = trtorch.logging.get_logging_prefix()
        self.assertEqual(new_prefix, logging_prefix)

    def test_reportable_log_level(self):
-        new_level = trtorch.logging.Level.Warning
+        new_level = trtorch.logging.Level.Error
        trtorch.logging.set_reportable_log_level(new_level)
        level = trtorch.logging.get_reportable_log_level()
        self.assertEqual(new_level, level)
@@ -78,10 +99,11 @@ def test_is_colored_output_on(self):

def test_suite():
    suite = unittest.TestSuite()
+    suite.addTest(unittest.makeSuite(TestLoggingAPIs))
    suite.addTest(TestCompile.parametrize(TestCompile, model=models.resnet18(pretrained=True)))
    suite.addTest(TestCompile.parametrize(TestCompile, model=models.mobilenet_v2(pretrained=True)))
+    suite.addTest(TestPTtoTRTtoPT.parametrize(TestPTtoTRTtoPT, model=models.mobilenet_v2(pretrained=True)))
    suite.addTest(unittest.makeSuite(TestCheckMethodOpSupport))
-    suite.addTest(unittest.makeSuite(TestLoggingAPIs))

    return suite

51 changes: 51 additions & 0 deletions tests/py/test_trt_intercompatability.py
@@ -0,0 +1,51 @@
import unittest
import trtorch
import torch
import torchvision.models as models
import tensorrt as trt

from model_test_case import ModelTestCase


class TestPyTorchToTRTEngine(ModelTestCase):
    def setUp(self):
        self.input = torch.randn((1, 3, 224, 224)).to("cuda:0")
        self.ts_model = torch.jit.script(self.model)

    def test_pt_to_trt(self):
        compile_spec = {
            "input_shapes": [self.input.shape],
            "device": {
                "device_type": trtorch.DeviceType.GPU,
                "gpu_id": 0,
                "dla_core": 0,
                "allow_gpu_fallback": False,
                "disable_tf32": False
            }
        }

        trt_engine = trtorch.convert_method_to_trt_engine(self.ts_model, "forward", compile_spec)

        TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
        with trt.Runtime(TRT_LOGGER) as rt:
            engine = rt.deserialize_cuda_engine(trt_engine)
            with engine.create_execution_context() as ctx:
                out = torch.empty(size=tuple(engine.get_binding_shape(1))).to("cuda:0")
                bindings = [self.input.contiguous().data_ptr(), out.contiguous().data_ptr()]
                ctx.execute_async(batch_size=1, bindings=bindings, stream_handle=torch.cuda.current_stream(device='cuda:0').cuda_stream)
                same = (out - self.ts_model(self.input)).abs().max()
                self.assertTrue(same < 2e-3)


def test_suite():
    suite = unittest.TestSuite()
    suite.addTest(TestPyTorchToTRTEngine.parametrize(TestPyTorchToTRTEngine, model=models.resnet18(pretrained=True)))

    return suite


suite = test_suite()

runner = unittest.TextTestRunner()
result = runner.run(suite)

exit(int(not result.wasSuccessful()))
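
Note that the comparison in test_pt_to_trt needs no explicit synchronization: execute_async is enqueued on the CUDA stream PyTorch is currently using, so the tensor ops that follow are ordered behind the engine launch. If the engine were instead launched on a dedicated stream, an explicit sync would be required before reading out — a sketch of that variant, reusing ctx and bindings from the test:

    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        ctx.execute_async(batch_size=1, bindings=bindings, stream_handle=stream.cuda_stream)
    stream.synchronize()  # ensure the engine has written `out` before the host-side comparison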
