Properly set mutable buffer lifespans #12182

Open · wants to merge 1 commit into main

exir/memory_planning.py (17 changes: 13 additions & 4 deletions)
@@ -301,7 +301,11 @@ def _is_inplace_node(node: torch.fx.Node) -> bool:
 
 
 def update_tensor_lifetime(
-    node: torch.fx.Node, spec: TensorSpec, node_idx: int
+    node: torch.fx.Node,
+    spec: TensorSpec,
+    node_idx: int,
+    max_node_idx: int,
+    gs: Optional[ExportGraphSignature] = None,
 ) -> None:
     r"""
     Update the lifetime of the tensor to cover node_idx. A tensor's lifetime
@@ -317,7 +321,12 @@ def update_tensor_lifetime(
         start = 0
     else:
         start = node_idx if start is None or start > node_idx else start
-    end = node_idx if end is None or end < node_idx else end
+
+    if node.op == "placeholder" and _is_mutable_buffer(node, gs):
+        # mutable buffers are never freed
+        end = max_node_idx
+    else:
+        end = node_idx if end is None or end < node_idx else end
     spec.lifetime = [start, end]
 
 
@@ -497,7 +506,7 @@ def update_all_tensors_lifetime(
     Set the lifetime for all the tensors encountered in the Fx graph.
     """
     specs = set()
-
+    max_node_idx = len(graph_module.graph.nodes) - 1
     for node_idx, node in enumerate(graph_module.graph.nodes):
         for spec in collect_specs_from_nodes(
             filter_nodes(itertools.chain([node], node.args, node.kwargs.values())),
@@ -509,7 +518,7 @@ def update_all_tensors_lifetime(
             do_assertion=False,
             ignore_dynamic_unbound_tensor=False,
         ):
-            update_tensor_lifetime(node, spec, node_idx)
+            update_tensor_lifetime(node, spec, node_idx, max_node_idx, graph_signature)
         specs.add(spec)
     return specs
 
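For context on the check used above: `_is_mutable_buffer(node, gs)` consults the export graph signature to tell buffers that the program writes back to apart from ordinary inputs. Below is a minimal sketch of that idea, assuming `torch.export`'s `ExportGraphSignature` with its `inputs_to_buffers` and `buffers_to_mutate` maps; it is an illustration of the check, not the exact helper in exir/memory_planning.py.

from typing import Optional

import torch
from torch.export import ExportGraphSignature


def is_mutable_buffer_sketch(
    node: torch.fx.Node, gs: Optional[ExportGraphSignature]
) -> bool:
    # Only placeholder nodes can be graph inputs backed by a buffer.
    if gs is None or node.op != "placeholder":
        return False
    # inputs_to_buffers maps placeholder names to buffer fully-qualified names.
    if node.name not in gs.inputs_to_buffers:
        return False
    fqn = gs.inputs_to_buffers[node.name]
    # buffers_to_mutate maps output names to the buffer FQNs they write back to,
    # so a buffer that appears among its values is mutated by the program.
    return fqn in gs.buffers_to_mutate.values()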
exir/tests/test_memory_planning.py (41 changes: 41 additions & 0 deletions)
@@ -664,6 +664,47 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             .val.allocation_info.memory_offset_high,
         )
 
+    def test_mutable_buffers_infinite_lifespan(self) -> None:
+        class Simple(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.register_buffer("state", torch.zeros(1))
+
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                self.state.index_put_(
+                    [
+                        torch.tensor([0]),
+                    ],
+                    x,
+                )
+                y = x + self.state
+                z = x * y
+                return z
+
+        model = Simple()
+        inputs = (torch.ones(1),)
+
+        et = to_edge(export(model, inputs, strict=True)).to_executorch(
+            ExecutorchBackendConfig(
+                emit_mutable_buffer_names=True, run_reinplace_pass=True
+            )
+        )
+
+        serialized_state = et.executorch_program.execution_plan[0].values[0].val
+        self.assertEqual(
+            serialized_state.extra_tensor_info.fully_qualified_name, "state"
+        )
+        memory_base = serialized_state.allocation_info.memory_offset_low
+        memory_size = memory_base + 4  # 4 bytes for a single float
+        for value in et.executorch_program.execution_plan[0].values[1:]:
+            val = value.val
+            if hasattr(val, "allocation_info") and val.allocation_info is not None:
+                not_overlapping = (
+                    val.allocation_info.memory_offset_low < memory_base
+                    or val.allocation_info.memory_offset_low >= memory_size
+                )
+                self.assertTrue(not_overlapping)
+
     def test_constants_not_memory_planned(self) -> None:
         class Simple(torch.nn.Module):
             def __init__(self) -> None:
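The test checks the outcome at the serialized-program level. The invariant it relies on is the usual memory-planning rule: two tensors may share storage only if their lifetimes are disjoint, so pinning a mutable buffer's lifetime to the last node index keeps every other planned tensor off its storage. A small self-contained sketch of that rule follows; the function name and tuple layout are illustrative, not taken from ExecuTorch.

def lifetimes_overlap(a: tuple[int, int], b: tuple[int, int]) -> bool:
    # Lifetimes are inclusive [start, end] node-index ranges; if they overlap,
    # the planner must give the two tensors non-overlapping memory.
    return a[0] <= b[1] and b[0] <= a[1]


# With this change, a mutable buffer placeholder spans the whole program
# (here max_node_idx = 10), so it overlaps every other tensor's lifetime and
# its storage can never be reused.
mutable_buffer_lifetime = (0, 10)
scratch_lifetime = (3, 5)
assert lifetimes_overlap(mutable_buffer_lifetime, scratch_lifetime)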