Skip to content

Commit

Permalink
perf: Delegate children computation at runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
pawamoy committed Sep 18, 2021
1 parent fee304d commit 8d54c87
Showing 1 changed file with 21 additions and 39 deletions.
60 changes: 21 additions & 39 deletions src/griffe/extended_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import ast
import inspect
from ast import AST, iter_child_nodes # noqa: WPS458
from ast import AST # noqa: WPS458
from functools import cached_property


Expand All @@ -33,8 +33,22 @@ class RootNodeError(Exception):

class _ExtendedAST:
@cached_property
def children(self) -> list[AST]:
return list(iter_child_nodes(self)) # type: ignore
def children(self) -> list[AST]: # noqa: WPS231
children = []
for field_name in self._fields: # type: ignore # noqa: WPS437
try:
field = getattr(self, field_name)
except AttributeError:
continue
if isinstance(field, AST):
field.parent = self # type: ignore
children.append(field)
elif isinstance(field, list):
for child in field:
if isinstance(child, AST):
child.parent = self # type: ignore
children.append(child)
return children

@cached_property
def position(self) -> int:
Expand All @@ -57,7 +71,7 @@ def next_siblings(self) -> list[AST]:

@cached_property
def siblings(self) -> list[AST]:
return reversed(self.previous_siblings) + self.next_siblings # type: ignore
return [*reversed(self.previous_siblings), *self.next_siblings] # type: ignore

@cached_property
def previous(self) -> AST:
Expand Down Expand Up @@ -91,45 +105,13 @@ def last_child(self) -> AST: # noqa: A003
_patched = False


def extend_ast(force: bool = False) -> None:
"""Extend the base `ast.AST` class to provide more functionality.
Arguments:
force: Whether to force re-patching if it was already done.
"""
def extend_ast() -> None:
"""Extend the base `ast.AST` class to provide more functionality."""
global _patched # noqa: WPS420
if _patched and not force:
if _patched:
return
for name, member in inspect.getmembers(ast):
if name != "AST" and inspect.isclass(member):
if AST in member.__bases__: # noqa: WPS609
member.__bases__ = (*member.__bases__, _ExtendedAST) # noqa: WPS609
_patched = True # noqa: WPS122,WPS442


def link_tree(root_node: AST) -> None:
"""Link nodes between them.
This will set the `parent` and `children` attributes on every node in the tree.
Arguments:
root_node: The root node, to start from.
"""
root_node.parent = None # type: ignore
_link_tree(root_node)


def _link_tree(node: AST) -> None: # noqa: WPS231
for field_name in node._fields: # noqa: WPS437
try:
field = getattr(node, field_name)
except AttributeError:
continue
if isinstance(field, AST):
field.parent = node # type: ignore
_link_tree(field)
elif isinstance(field, list):
for child in field:
if isinstance(child, AST):
child.parent = node # type: ignore
_link_tree(child)

0 comments on commit 8d54c87

Please sign in to comment.