Skip to content

Commit

Permalink
perf: Factorize and improve main and extensions visitors
Browse files Browse the repository at this point in the history
  • Loading branch information
pawamoy committed Sep 18, 2021
1 parent 8d54c87 commit 9b27b56
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 22 deletions.
47 changes: 43 additions & 4 deletions src/griffe/extensions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@

import ast
import enum
from typing import Type
from typing import TYPE_CHECKING, Type

if TYPE_CHECKING:
from griffe.visitor import _MainVisitor as MainVisitor # noqa: WPS450


class When(enum.Enum):
Expand All @@ -23,12 +26,39 @@ class When(enum.Enum):
visit_stops: int = 4


class Extension(ast.NodeVisitor):
class _BaseVisitor:
def visit(self, node: ast.AST, parent: ast.AST | None = None) -> None:
self._visit(node, parent=parent)

def generic_visit(self, node: ast.AST) -> None: # noqa: WPS231
# optimisation: got rid of the two generators iter_fields and iter_child_nodes
for field_name in node._fields: # noqa: WPS437
try:
field = getattr(node, field_name)
except AttributeError:
continue
if isinstance(field, ast.AST):
self.visit(field, parent=node)
elif isinstance(field, list):
for child in field:
if isinstance(child, ast.AST):
self.visit(child, parent=node)

def _run_specific_or_generic(self, node):
# optimisation: no extra variable, f-string instead of concatenation
getattr(self, f"visit_{node.__class__.__name__}", self.generic_visit)(node)

def _visit(self, node: ast.AST, parent: ast.AST | None = None) -> None:
return self._run_specific_or_generic(node)


class Extension(_BaseVisitor):
"""The node visitor extension base class, to inherit from."""

need_parents = False
when: When

def __init__(self, main_visitor: ast.NodeVisitor) -> None:
def __init__(self, main_visitor: MainVisitor) -> None:
"""Initialize the visitor extension.
Arguments:
Expand All @@ -50,6 +80,15 @@ def __init__(self, *extensions_classes: Type[Extension]) -> None:
self._classes: list[Type[Extension]] = list(extensions_classes)
self._instances: dict[When, list[Extension]] = {}

@property
def need_parents(self) -> bool:
    """Tell whether at least one contained extension needs the whole parents chain while visiting.

    Returns:
        True or False.
    """
    for extension_class in self._classes:
        if extension_class.need_parents:
            return True
    return False

@property
def when_visit_starts(self) -> list[Extension]:
"""Return the visitors that run when the visit starts.
Expand Down Expand Up @@ -105,7 +144,7 @@ def add(self, *extensions_classes: Type[Extension]) -> None:
"""
self._classes.extend(extensions_classes)

def instantiate(self, main_visitor: ast.NodeVisitor) -> Extensions:
def instantiate(self, main_visitor: MainVisitor) -> Extensions:
"""Clear and instantiate the visitor classes.
Arguments:
Expand Down
45 changes: 27 additions & 18 deletions src/griffe/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from griffe.collections import lines_collection
from griffe.dataclasses import Argument, Class, Decorator, Docstring, Function, Module
from griffe.extensions import Extensions
from griffe.extensions.base import _BaseVisitor # noqa: WPS450


def visit(
Expand All @@ -34,7 +35,7 @@ def visit(
Returns:
The module, with its members populated.
"""
return _Visitor(module_name, filepath, code, extensions or Extensions()).get_module()
return _MainVisitor(module_name, filepath, code, extensions or Extensions()).get_module()


def _get_docstring(node):
Expand All @@ -48,7 +49,7 @@ def _get_docstring(node):
return None


class _Visitor(ast.NodeVisitor):
class _MainVisitor(_BaseVisitor): # noqa: WPS338
def __init__(
self,
module_name: str,
Expand All @@ -57,41 +58,49 @@ def __init__(
extensions: Extensions,
) -> None:
super().__init__()
self.module_name = module_name
self.filepath = filepath
self.code = code
self.extensions = extensions.instantiate(self)
self.module_name: str = module_name
self.filepath: Path = filepath
self.code: str = code
self.extensions: Extensions = extensions.instantiate(self)
# self.scope = defaultdict(dict)
self.root = None
self.node = None
self.root: ast.AST | None = None
self.parent: ast.AST | None = None
self.current: Module | Class | Function = None # type: ignore
self.in_decorator: bool = False
if self.extensions.need_parents:
self._visit = self._visit_set_parents # type: ignore

def _visit_set_parents(self, node: ast.AST, parent: ast.AST | None = None) -> None:
node.parent = parent # type: ignore
self._run_specific_or_generic(node)

def get_module(self) -> Module:
top_node = ast.parse(self.code)
link_tree(top_node)
# optimisation: equivalent to ast.parse, but with optimize=1 to remove assert statements
# TODO: with options, could use optimize=2 to remove docstrings
top_node = compile(self.code, mode="exec", filename=str(self.filepath), flags=ast.PyCF_ONLY_AST, optimize=1)
self.visit(top_node)
return self.current.module # type: ignore # there's always a module after the visit

def visit(self, node: ast.AST) -> None:
def visit(self, node: ast.AST, parent: ast.AST | None = None) -> None:
for start_visitor in self.extensions.when_visit_starts:
start_visitor.visit(node)
super().visit(node)
start_visitor.visit(node, parent)
super().visit(node, parent)
for stop_visitor in self.extensions.when_visit_stops:
stop_visitor.visit(node)
stop_visitor.visit(node, parent)

def generic_visit(self, node: ast.AST) -> None:
def generic_visit(self, node: ast.AST) -> None: # noqa: WPS231
for start_visitor in self.extensions.when_children_visit_starts:
start_visitor.visit(node)
super().generic_visit(node)
for stop_visitor in self.extensions.when_children_visit_stops:
stop_visitor.visit(node)

def visit_Module(self, node) -> None:
self.current = Module(self.module_name, filepath=self.filepath, docstring=_get_docstring(node))
self.current = Module(name=self.module_name, filepath=self.filepath, docstring=_get_docstring(node))
self.generic_visit(node)

def visit_ClassDef(self, node) -> None:
class_ = Class(node.name, lineno=node.lineno, endlineno=node.end_lineno, docstring=_get_docstring(node))
class_ = Class(name=node.name, lineno=node.lineno, endlineno=node.end_lineno, docstring=_get_docstring(node))
self.current[node.name] = class_
self.current = class_
self.generic_visit(node)
Expand Down Expand Up @@ -147,7 +156,7 @@ def visit_FunctionDef(self, node) -> None: # noqa: WPS231
returns = None

function = Function(
node.name,
name=node.name,
lineno=lineno,
endlineno=node.end_lineno,
arguments=arguments,
Expand Down

0 comments on commit 9b27b56

Please sign in to comment.