slice_scatter decomposition
Changing the decomposition pattern

slice_scatter changes

Addressing review comments

Removing arange and replacing it with range

Adding slice_scatter to the decomposition group

Using aten::scatter in aten.slice_scatter

Correcting the slice_scatter case to use aten::scatter

Removing unnecessary cases from the slice_scatter implementation and adding a test case

Changing the for loop to torch.arange

Reverting torch.arange back to a for loop

Adding a test case for 3D inputs; removing the cast to torch.int64 and passing the dtype to torch.ones instead

Removing aten.index from the decomposition ops
apbose committed May 30, 2024
1 parent dfc31c7 commit 2b101dd
Showing 2 changed files with 234 additions and 0 deletions.
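For context, aten.slice_scatter(input, src, dim, start, end, step) writes src into the slice input[..., start:end:step, ...] along dim. The commit lowers it to aten.scatter with an explicitly materialized index tensor; a minimal sketch of that equivalence (illustration only, not part of the diff):

```python
import torch

x = torch.zeros(8, 8)
src = torch.ones(8, 2)

# slice_scatter writes src into columns 6:8 of x.
expected = torch.ops.aten.slice_scatter(x, src, 1, 6, None, 1)

# Equivalent scatter: an index tensor of src's shape whose entries along
# dim name the target positions (here columns 6 and 7).
index = torch.stack([i * torch.ones(8, dtype=torch.long) for i in (6, 7)], dim=1)
assert torch.equal(torch.scatter(x, 1, index, src), expected)
```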
39 changes: 39 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -4,6 +4,7 @@
import torch
from torch._decomp import register_decomposition
from torch._ops import OpOverload
from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim

from ._decomposition_groups import (
    ENABLED_TORCH_DECOMPOSITIONS,
@@ -174,6 +175,44 @@ def empty_permuted_decomposition(*args, **kwargs) -> torch.Tensor:
    return torch.empty([empty_size[l] for l in empty_permute], **kwargs).permute(perm)


@register_torch_trt_decomposition(
    torch.ops.aten.slice_scatter.default, registry=TORCH_TRT_DECOMPOSITIONS
)
def slice_scatter_decomposition(
    input_tensor: torch.Tensor,
    src_tensor: torch.Tensor,
    dim: int,
    start: Optional[int] = None,
    end: Optional[int] = None,
    step: Optional[int] = None,
) -> torch.Tensor:
    dim_size = input_tensor.shape[dim]
    if start is None:
        start = 0
    start = get_positive_dim(start, dim_size)
    if end is None:
        end = dim_size
    end = get_positive_dim(end, dim_size)
    if step is None:
        step = 1

    # step == 0 is not a valid torch case, and src_tensor's shape must match
    # the shape of the slice it is scattered into.
    src_dim = src_tensor.shape

    # A full slice with unit step simply replaces the input.
    if start == 0 and end == dim_size and step == 1:
        return src_tensor

    # Build an index tensor of src_tensor's shape whose entries along `dim`
    # enumerate the target positions start, start + step, ..., then scatter.
    cat_tensors = []
    index_tensor_shape = []
    for i, src_each_dim in enumerate(list(src_dim)):
        if i != dim:
            index_tensor_shape.append(src_each_dim)
    for index in range(start, end, step):
        cat_tensors.append(index * torch.ones(index_tensor_shape, dtype=torch.long))
    index_tensor = torch.stack(cat_tensors, dim).cuda()
    output_tensor = torch.scatter(input_tensor, dim, index_tensor, src_tensor)
    return output_tensor


def get_decompositions(
    enable_experimental_decompositions: bool = False,
) -> Dict[OpOverload, Callable[[Any], Any]]:
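A quick way to sanity-check the new decomposition against eager PyTorch; a sketch assuming a CUDA device (the implementation moves its index tensor to CUDA) and that the decorated function can be imported directly from the module above:

```python
import torch
from torch_tensorrt.dynamo.lowering._decompositions import slice_scatter_decomposition

x = torch.zeros(8, 8, device="cuda")
src = torch.ones(2, 8, device="cuda")

# dim=0, start=2, end=6, step=2 scatters src into rows 2 and 4.
expected = torch.ops.aten.slice_scatter(x, src, 0, 2, 6, 2)
actual = slice_scatter_decomposition(x, src, 0, 2, 6, 2)
assert torch.equal(actual, expected)
```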
195 changes: 195 additions & 0 deletions tests/py/dynamo/lowering/test_decompositions.py
@@ -484,6 +484,201 @@ def forward(self, x):
f"The optimized model results shape and torch model results shape should be equal in empty_like",
)

    def test_lowering_slice_scatter_dimOne_module(self):
        class sliceScatter(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)

            def forward(self, x, src, dim, start=None, end=None, step=1):
                y = torch.ops.aten.slice_scatter(x, src, dim, start, end, step)
                return y

        # Operations expected to be removed in the traced graph after decompositions
        expected_ops = {
            torch.ops.aten.scatter.src,
        }
        unexpected_ops = {torch.ops.aten.slice_scatter}

        inputs = [torch.zeros(8, 8).cuda(), torch.ones(8, 2).cuda(), 1, 6, None, 1]

        fx_graph = torch.fx.symbolic_trace(sliceScatter())
        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
            fx_graph,
            inputs,
            expected_ops=expected_ops,
            unexpected_ops=unexpected_ops,
            min_block_size=1,
        )

        self.assertEqual(
            len(unexpected_ops_seen),
            0,
            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
        )

        self.assertEqual(
            len(expected_ops_unseen),
            0,
            f"The following expected ops were not encountered: {expected_ops_unseen}",
        )

        torch._dynamo.reset()

        # Validate that the results between Torch and Torch-TRT are similar
        optimized_model = torch_tensorrt.compile(
            fx_graph,
            "torch_compile",
            inputs,
            min_block_size=1,
            truncate_long_and_double=True,
            pass_through_build_failures=True,
        )
        optimized_model_results = optimized_model(*inputs).detach().cpu()
        torch_model_results = fx_graph(*inputs).detach().cpu()

        max_diff = float(
            torch.max(torch.abs(optimized_model_results - torch_model_results))
        )
        self.assertAlmostEqual(
            max_diff,
            0,
            DECIMALS_OF_AGREEMENT,
            "Slice_scatter TRT outputs don't match the original model.",
        )

    def test_lowering_slice_scatter_dimZero_StepTwo_module(self):
        class sliceScatter(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)

            def forward(self, x, src, dim, start, end, step):
                y = torch.ops.aten.slice_scatter.default(x, src, dim, start, end, step)
                return y

        # Operations expected to be removed in the traced graph after decompositions
        expected_ops = {
            torch.ops.aten.scatter.src,
        }
        unexpected_ops = {torch.ops.aten.slice_scatter}

        inputs = [torch.zeros(8, 8).cuda(), torch.ones(2, 8).cuda(), 0, 2, 6, 2]

        fx_graph = torch.fx.symbolic_trace(sliceScatter())

        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
            fx_graph,
            inputs,
            expected_ops=expected_ops,
            unexpected_ops=unexpected_ops,
            min_block_size=1,
        )

        self.assertEqual(
            len(unexpected_ops_seen),
            0,
            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
        )

        self.assertEqual(
            len(expected_ops_unseen),
            0,
            f"The following expected ops were not encountered: {expected_ops_unseen}",
        )

        torch._dynamo.reset()

        # Validate that the results between Torch and Torch-TRT are similar
        optimized_model = torch_tensorrt.compile(
            fx_graph,
            "torch_compile",
            inputs,
            min_block_size=1,
            truncate_long_and_double=True,
            pass_through_build_failures=True,
        )
        optimized_model_results = optimized_model(*inputs).detach().cpu()
        torch_model_results = fx_graph(*inputs).detach().cpu()

        max_diff = float(
            torch.max(torch.abs(optimized_model_results - torch_model_results))
        )
        self.assertAlmostEqual(
            max_diff,
            0,
            DECIMALS_OF_AGREEMENT,
            "Slice_scatter TRT outputs don't match the original model.",
        )

    def test_lowering_slice_scatter_dimOne_3d_module(self):
        class sliceScatter(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)

            def forward(self, x, src, dim, start, end, step):
                y = torch.ops.aten.slice_scatter.default(x, src, dim, start, end, step)
                return y

        # Operations expected to be removed in the traced graph after decompositions
        expected_ops = {
            torch.ops.aten.scatter.src,
        }
        unexpected_ops = {torch.ops.aten.slice_scatter}

        inputs = [
            torch.zeros(8, 8, 8).cuda(),
            torch.ones(8, 2, 8).cuda(),
            1,
            6,
            None,
            1,
        ]

        fx_graph = torch.fx.symbolic_trace(sliceScatter())

        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
            fx_graph,
            inputs,
            expected_ops=expected_ops,
            unexpected_ops=unexpected_ops,
            min_block_size=1,
        )

        self.assertEqual(
            len(unexpected_ops_seen),
            0,
            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
        )

        self.assertEqual(
            len(expected_ops_unseen),
            0,
            f"The following expected ops were not encountered: {expected_ops_unseen}",
        )

        torch._dynamo.reset()

        # Validate that the results between Torch and Torch-TRT are similar
        optimized_model = torch_tensorrt.compile(
            fx_graph,
            "torch_compile",
            inputs,
            min_block_size=1,
            truncate_long_and_double=True,
            pass_through_build_failures=True,
        )
        optimized_model_results = optimized_model(*inputs).detach().cpu()
        torch_model_results = fx_graph(*inputs).detach().cpu()

        max_diff = float(
            torch.max(torch.abs(optimized_model_results - torch_model_results))
        )
        self.assertAlmostEqual(
            max_diff,
            0,
            DECIMALS_OF_AGREEMENT,
            "Slice_scatter TRT outputs don't match the original model.",
        )


if __name__ == "__main__":
    run_tests()
