Skip to content

Arm backend: Add sign decomposition pass and test #12159

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from .decompose_ne_pass import DecomposeNotEqualPass # noqa
from .decompose_round_pass import DecomposeRoundPass # noqa
from .decompose_select import DecomposeSelectPass # noqa
from .decompose_sign_pass import DecomposeSignPass # noqa
from .decompose_silu_pass import DecomposeSiluPass # noqa
from .decompose_sinh_pass import DecomposeSinhPass # noqa
from .decompose_softmax_pass import DecomposeSoftmaxPass # noqa
Expand Down
3 changes: 3 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
DecomposeNotEqualPass,
DecomposeRoundPass,
DecomposeSelectPass,
DecomposeSignPass,
DecomposeSiluPass,
DecomposeSinhPass,
DecomposeSoftmaxPass,
Expand Down Expand Up @@ -158,6 +159,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
self.add_pass(ConvertIntPowToMuls())
self.add_pass(CastBoolToInt8Pass())
self.add_pass(DecomposeSinhPass())
self.add_pass(DecomposeSignPass())
self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI())
self.add_pass(DecomposeEmbeddingPass())
self.add_pass(FuseQuantizedActivationPass())
Expand Down Expand Up @@ -242,6 +244,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
self.add_pass(DecomposeScaledDotProductAttention())
self.add_pass(DecomposeRoundPass())
self.add_pass(CastBoolToInt8Pass())
self.add_pass(DecomposeSignPass())
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
self.add_pass(ScalarsToAttributePass())
self.add_pass(DecomposeGroupNormPass())
Expand Down
73 changes: 73 additions & 0 deletions backends/arm/_passes/decompose_sign_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch

from executorch.backends.arm._passes import ArmPass
from executorch.exir.dialects._ops import ops as exir_ops


# For MI case
edge_sign = exir_ops.edge.aten.sign.default
# For BI case
aten_sign = torch.ops.aten.sign.default


def get_ops(op):
"""Returns the appropriate operator functions based on the input operator."""
if op == edge_sign:
return (
exir_ops.edge.aten.gt.Scalar,
exir_ops.edge.aten.lt.Scalar,
exir_ops.edge.aten.where.self,
exir_ops.edge.aten.neg.default,
exir_ops.edge.aten.mul.Scalar,
exir_ops.edge.aten.add.Scalar,
)
elif op == aten_sign:
return (
torch.ops.aten.gt.Scalar,
torch.ops.aten.lt.Scalar,
torch.ops.aten.where.self,
torch.ops.aten.neg.default,
torch.ops.aten.mul.Scalar,
torch.ops.aten.add.Scalar,
)
else:
raise ValueError(f"Unsupported operator: {op}")


class DecomposeSignPass(ArmPass):
"""Decomposes the sign operator into a sequence of operations that are supported by the Arm backend."""

def call_operator(self, op, args, kwargs, meta):
if op not in (edge_sign, aten_sign):
return super().call_operator(op, args, kwargs, meta)

gt_op, lt_op, where_op, neg_op, mul_op, add_op = get_ops(op)

x = args[0]

gt_mask = super().call_operator(gt_op, (x, 0.0), {}, meta, updated=True)
lt_mask = super().call_operator(lt_op, (x, 0.0), {}, meta, updated=True)

zeros = super().call_operator(mul_op, (x, 0.0), {}, meta, updated=True)
ones = super().call_operator(add_op, (zeros, 1.0), {}, meta, updated=True)
neg_ones = super().call_operator(neg_op, (ones,), {}, meta, updated=True)

negative_tensor = super().call_operator(
where_op, (lt_mask, neg_ones, zeros), {}, meta, updated=True
)
positive_tensor = super().call_operator(
where_op, (gt_mask, ones, zeros), {}, meta, updated=True
)

return super().call_operator(
where_op,
(lt_mask, negative_tensor, positive_tensor),
{},
meta,
updated=True,
)
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ def is_node_supported(
exir_ops.edge.aten.sinh.default,
exir_ops.edge.aten.atan.default,
exir_ops.edge.aten.acosh.default,
exir_ops.edge.aten.sign.default,
]

return supported
Expand Down
1 change: 1 addition & 0 deletions backends/arm/quantizer/quantization_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ def _match_pattern(
torch.ops.aten.sinh.default,
torch.ops.aten.atan.default,
torch.ops.aten.acosh.default,
torch.ops.aten.sign.default,
]

_one_to_one_shared_input_qspec = [
Expand Down
86 changes: 86 additions & 0 deletions backends/arm/test/ops/test_sign.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import pytest
import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
EthosU55PipelineBI,
EthosU85PipelineBI,
TosaPipelineBI,
TosaPipelineMI,
)

aten_op = "torch.ops.aten.sign.default"
exir_op = "executorch_exir_dialects_edge__ops_aten__sign_default"

input_t1 = Tuple[torch.Tensor]

test_data_suite = {
"zeros": torch.zeros(3, 5),
"ones": torch.ones(4, 4),
"neg_ones": -torch.ones(4, 4),
"mixed_signs": torch.tensor([[-2.0, -1.0, 0.0, 1.0, 2.0]]),
"positive_ramp": torch.arange(0.1, 1.1, 0.2),
"negative_ramp": torch.arange(-1.0, -0.1, 0.2),
"small_values": torch.tensor([-1e-7, 0.0, 1e-7]),
"rand": torch.rand(10, 10) - 0.5,
"rand_alt_shape": torch.rand(10, 3, 5) - 0.5,
"high_magnitude": torch.tensor([-1e6, -10.0, 0.0, 10.0, 1e6]),
}


class Sign(torch.nn.Module):
def forward(self, x: torch.Tensor):
return torch.sign(x)


@common.parametrize("test_data", test_data_suite)
def test_sign_tosa_MI(test_data: Tuple):
pipeline = TosaPipelineMI[input_t1](
Sign(),
(test_data,),
aten_op=aten_op,
exir_op=exir_op,
)
pipeline.run()


@common.parametrize("test_data", test_data_suite)
def test_sign_tosa_BI(test_data: Tuple):
pipeline = TosaPipelineBI[input_t1](
Sign(),
(test_data,),
aten_op=[],
exir_op=exir_op,
)
pipeline.run()


@common.XfailIfNoCorstone300
@common.parametrize("test_data", test_data_suite)
@pytest.mark.xfail(reason="where.self not supported on U55")
def test_sign_u55_BI(test_data: Tuple):
pipeline = EthosU55PipelineBI[input_t1](
Sign(),
(test_data,),
aten_ops=[],
exir_ops=exir_op,
)
pipeline.run()


@common.XfailIfNoCorstone320
@common.parametrize("test_data", test_data_suite)
def test_sign_u85_BI(test_data: Tuple):
pipeline = EthosU85PipelineBI[input_t1](
Sign(),
(test_data,),
aten_ops=[],
exir_ops=exir_op,
)
pipeline.run()
Loading