Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor the internal codebase from fx2trt_oss to torch_tensorrt #1104

Merged
merged 1 commit into from
Jun 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/_modules/torch_tensorrt/_compile.html
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ <h1>Source code for torch_tensorrt._compile</h1><div class="highlight"><pre>
<span class="c1"># profiling_verbosity=trt.ProfilingVerbosity.DETAILED, #For profile</span>
<span class="p">)</span>
<span class="c1"># For profile</span>
<span class="c1"># from fx2trt_oss.fx.tools.trt_profiler_sorted import profile_trt_module</span>
<span class="c1"># from torch_tensorrt.fx.tools.trt_profiler_sorted import profile_trt_module</span>
<span class="c1"># profile_trt_module(&quot;&quot;, trt_mod, acc_inputs)</span>
<span class="n">trt_mod</span> <span class="o">=</span> <span class="n">TRTModule</span><span class="p">(</span><span class="o">*</span><span class="n">r</span><span class="p">)</span>

Expand Down
2 changes: 1 addition & 1 deletion py/torch_tensorrt/_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def get_input(self, inputs):
# profiling_verbosity=trt.ProfilingVerbosity.DETAILED, #For profile
)
# For profile
# from fx2trt_oss.fx.tools.trt_profiler_sorted import profile_trt_module
# from torch_tensorrt.fx.tools.trt_profiler_sorted import profile_trt_module
# profile_trt_module("", trt_mod, acc_inputs)
trt_mod = TRTModule(*r)

Expand Down
11 changes: 5 additions & 6 deletions py/torch_tensorrt/fx/converters/acc_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,21 @@
import warnings
from typing import cast, Dict, Optional, Sequence, Tuple, Union

from ..tracer.acc_tracer import acc_ops
import numpy as np

# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch

from ..converter_registry import tensorrt_converter

from ..tracer.acc_tracer import acc_ops
from ..types import * # noqa: F403
from ..utils import (
get_dynamic_dims,
torch_dtype_from_trt,
torch_dtype_to_trt,
)
from torch.fx.immutable_collections import immutable_list
from torch.fx.node import Argument, Target

from ..utils import get_dynamic_dims, torch_dtype_from_trt, torch_dtype_to_trt

from .converter_utils import * # noqa: F403


Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/fx/converters/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch

from ..converter_registry import tensorrt_converter

from .converter_utils import mark_as_int8_layer
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/fx/converters/adaptive_avgpool.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch

from ..converter_registry import tensorrt_converter

from .converter_utils import extend_mod_attr_to_tuple, mark_as_int8_layer
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/fx/converters/add.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch

from ..converter_registry import tensorrt_converter

from .converter_utils import get_dyn_range, mark_as_int8_layer
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/fx/converters/batchnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch

from ..converter_registry import tensorrt_converter

from .converter_utils import get_dyn_range, mark_as_int8_layer, to_numpy
Expand Down
3 changes: 2 additions & 1 deletion py/torch_tensorrt/fx/converters/converter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch
from torch.fx.node import Argument, Target

from ..types import (
Shape,
TRTDataType,
Expand All @@ -18,7 +20,6 @@
TRTTensor,
)
from ..utils import torch_dtype_from_trt
from torch.fx.node import Argument, Target


def get_trt_plugin(
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/fx/converters/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
import tensorrt as trt
import torch

from ..converter_registry import tensorrt_converter

from .converter_utils import (
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/fx/converters/linear.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch

from ..converter_registry import tensorrt_converter

from .converter_utils import get_dyn_range, mark_as_int8_layer, to_numpy
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/fx/converters/maxpool.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch

from ..converter_registry import tensorrt_converter

from .converter_utils import extend_mod_attr_to_tuple, mark_as_int8_layer
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/fx/converters/mul.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch

from ..converter_registry import tensorrt_converter

from .converter_utils import get_dyn_range, mark_as_int8_layer
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/fx/converters/quantization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch

from ..converter_registry import tensorrt_converter

from .converter_utils import get_dyn_range, get_inputs_from_args_and_kwargs
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/fx/converters/transformation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch

from ..converter_registry import tensorrt_converter

from .converter_utils import mark_as_int8_layer
Expand Down
12 changes: 6 additions & 6 deletions py/torch_tensorrt/fx/example/fx2trt_example.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# type: ignore[]

import fx2trt_oss.tracer.acc_tracer.acc_tracer as acc_tracer
import torch
import torch.fx
import torch.nn as nn
from fx2trt_oss.fx import InputTensorSpec, TRTInterpreter, TRTModule
from fx2trt_oss.fx.tools.trt_splitter import TRTSplitter
import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter, TRTModule
from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter


# The purpose of this example is to demonstrate the overall flow of lowering a PyTorch
Expand Down Expand Up @@ -83,12 +83,12 @@ def forward(self, x):
%x : [#users=1] = placeholder[target=x]
%linear_weight : [#users=1] = get_attr[target=linear.weight]
%linear_bias : [#users=1] = get_attr[target=linear.bias]
%linear_1 : [#users=1] = call_function[target=fx2trt_oss.tracer.acc_tracer.acc_ops.linear](args = (), ...
%relu_1 : [#users=1] = call_function[target=fx2trt_oss.tracer.acc_tracer.acc_ops.relu](args = (), ...
%linear_1 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.linear](args = (), ...
%relu_1 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.relu](args = (), ...
return relu_1
graph():
%relu_1 : [#users=1] = placeholder[target=relu_1]
%linalg_norm_1 : [#users=1] = call_function[target=fx2trt_oss.tracer.acc_tracer.acc_ops.linalg_norm](args = (), ...
%linalg_norm_1 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.linalg_norm](args = (), ...
return linalg_norm_1
"""

Expand Down
4 changes: 2 additions & 2 deletions py/torch_tensorrt/fx/example/lower_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

import torch
import torchvision
from fx2trt_oss.fx.lower import lower_to_trt
from fx2trt_oss.fx.utils import LowerPrecision
from torch_tensorrt.fx.lower import lower_to_trt
from torch_tensorrt.fx.utils import LowerPrecision


"""
Expand Down
8 changes: 4 additions & 4 deletions py/torch_tensorrt/fx/example/quantized_resnet_test.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import copy

import fx2trt_oss.tracer.acc_tracer.acc_tracer as acc_tracer

# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch.fx

import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
import torchvision.models as models
from fx2trt_oss.fx import InputTensorSpec, TRTInterpreter, TRTModule
from fx2trt_oss.fx.utils import LowerPrecision
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
from torch.fx.experimental.normalize import NormalizeArgs
from torch.fx.passes import shape_prop
from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter, TRTModule
from torch_tensorrt.fx.utils import LowerPrecision

rn18 = models.resnet18().eval()

Expand Down
38 changes: 23 additions & 15 deletions py/torch_tensorrt/fx/example/test_fx2trt.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,35 @@
import torch_tensorrt
import torch
import torch_tensorrt


class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(5,3)
self.linear = torch.nn.Linear(5, 3)
self.relu = torch.nn.functional.relu
def forward(self,x):

def forward(self, x):
x = self.linear(x)
x = self.relu(x)
return x


model = MyModel().eval() # torch module needs to be in eval (not training) mode
model = MyModel().eval() # torch module needs to be in eval (not training) mode

# torch tensorrt
inputs = [torch_tensorrt.Input(
(2,5),
dtype=torch.half,
)]
enabled_precisions = {torch.float, torch.half} # Run with fp16

trt_ts_module = torch_tensorrt.compile(model, inputs=inputs, enabled_precisions=enabled_precisions)

inputs_ts = [torch.ones(2,5)]
inputs = [
torch_tensorrt.Input(
(2, 5),
dtype=torch.half,
)
]
enabled_precisions = {torch.float, torch.half} # Run with fp16

trt_ts_module = torch_tensorrt.compile(
model, inputs=inputs, enabled_precisions=enabled_precisions
)

inputs_ts = [torch.ones(2, 5)]
inputs_ts = [i.cuda().half() for i in inputs_ts]
result = trt_ts_module(*inputs_ts)
print(result)
Expand All @@ -33,12 +39,14 @@ def forward(self,x):
print(ref)

# fx2trt
inputs_fx = [torch.ones((2,5))]
inputs_fx = [torch.ones((2, 5))]

model.cuda().half()
inputs_fx = [i.cuda().half() for i in inputs_fx]

trt_fx_module = torch_tensorrt.compile(model, ir="fx", inputs=inputs_fx, enabled_precisions={torch.half})
trt_fx_module = torch_tensorrt.compile(
model, ir="fx", inputs=inputs_fx, enabled_precisions={torch.half}
)
result = trt_fx_module(*inputs_fx)
print(result)

Expand Down
4 changes: 2 additions & 2 deletions py/torch_tensorrt/fx/example/torchdynamo_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import torch
import torchdynamo
import torchvision
from fx2trt_oss.fx.lower import lower_to_trt
from fx2trt_oss.fx.utils import LowerPrecision
from torch_tensorrt.fx.lower import lower_to_trt
from torch_tensorrt.fx.utils import LowerPrecision
from torchdynamo.optimizations import backends

"""
Expand Down
2 changes: 1 addition & 1 deletion py/torch_tensorrt/fx/fx2trt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
import tensorrt as trt
import torch
import torch.fx
from .observer import Observer
from torch.fx.node import _get_qualified_name
from torch.fx.passes.shape_prop import TensorMetadata

from .converter_registry import CONVERTERS
from .input_tensor_spec import InputTensorSpec
from .observer import Observer
from .utils import get_dynamic_dims, LowerPrecision, torch_dtype_to_trt


Expand Down
16 changes: 10 additions & 6 deletions py/torch_tensorrt/fx/lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,27 @@
import logging
from typing import Any, Callable, Sequence

from .tracer.acc_tracer import acc_tracer

# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch
import torch.fx as fx
import torch.nn as nn
from .lower_setting import LowerSetting
from .passes.pass_utils import decorate_method, validate_inference
from .passes.splitter_base import SplitResult
from torch.fx.passes.splitter_base import SplitResult

from .fx2trt import TRTInterpreter, TRTInterpreterResult
from .input_tensor_spec import InputTensorSpec
from .lower_setting import LowerSetting
from .passes.lower_pass_manager_builder import LowerPassManagerBuilder
from .passes.pass_utils import chain_passes, PassFunc
from .passes.pass_utils import (
chain_passes,
decorate_method,
PassFunc,
validate_inference,
)
from .tools.timing_cache_utils import TimingCacheManager
from .tools.trt_splitter import TRTSplitter, TRTSplitterSetting

from .tracer.acc_tracer import acc_tracer
from .trt_module import TRTModule
from .utils import LowerPrecision

Expand Down
10 changes: 4 additions & 6 deletions py/torch_tensorrt/fx/lower_setting.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
import dataclasses as dc
from typing import List, Optional, Sequence, Set, Type

from .input_tensor_spec import InputTensorSpec
from .passes.lower_basic_pass import (
fuse_permute_linear,
fuse_permute_matmul,
)
from .utils import LowerPrecision
from torch import nn
from torch.fx.passes.pass_manager import PassManager

from .input_tensor_spec import InputTensorSpec
from .passes.lower_basic_pass import fuse_permute_linear, fuse_permute_matmul
from .utils import LowerPrecision


@dc.dataclass
class LowerSetting:
Expand Down
18 changes: 10 additions & 8 deletions py/torch_tensorrt/fx/passes/lower_basic_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
import warnings
from typing import Any

from ..tracer.acc_tracer import acc_ops
import torch
import torch.fx
from torch.fx.experimental.const_fold import split_const_subgraphs

from ..observer import observable
from .pass_utils import log_before_after, validate_inference

from ..tracer.acc_tracer import acc_ops
from ..tracer.acc_tracer.acc_utils import get_attr
from torch.fx.experimental.const_fold import split_const_subgraphs
from .pass_utils import log_before_after, validate_inference

# Create an alias for module input type to avoid littering pyre-ignore for Any
# throughout the file.
Expand Down Expand Up @@ -46,15 +48,15 @@ def fuse_sparse_matmul_add(gm: torch.fx.GraphModule, input: Input):
def forward(self, x):
a = self.a
b = self.b
addmm_mm = fx2trt_oss.tracer.acc_tracer.acc_ops.matmul(input = a, other = b); a = b = None
addmm_add = fx2trt_oss.tracer.acc_tracer.acc_ops.add(input = addmm_mm, other = x); addmm_mm = x = None
addmm_mm = torch_tensorrt.fx.tracer.acc_tracer.acc_ops.matmul(input = a, other = b); a = b = None
addmm_add = torch_tensorrt.fx.tracer.acc_tracer.acc_ops.add(input = addmm_mm, other = x); addmm_mm = x = None
return addmm_add

After:
def forward(self, x):
a = self.a
b = self.b
linear_1 = fx2trt_oss.tracer.acc_tracer.acc_ops.linear(input = a, weight = b, bias = x); a = b = x = None
linear_1 = torch_tensorrt.fx.tracer.acc_tracer.acc_ops.linear(input = a, weight = b, bias = x); a = b = x = None
return linear_1
"""
counter = 0
Expand Down Expand Up @@ -198,8 +200,8 @@ def fuse_permute_matmul(gm: torch.fx.GraphModule, input: Input):
try:
# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
from fx2trt_oss.fx.converter_registry import tensorrt_converter
from fx2trt_oss.fx.converters.converter_utils import (
from torch_tensorrt.fx.converter_registry import tensorrt_converter
from torch_tensorrt.fx.converters.converter_utils import (
add_binary_elementwise_layer,
broadcast,
get_trt_tensor,
Expand Down
Loading