Skip to content

Commit a257b90

Browse files
committed
perf(langserver): increase performance of visitor a little bit more
1 parent f6a2667 commit a257b90

File tree

3 files changed

+75
-50
lines changed

3 files changed

+75
-50
lines changed

hatch.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ extra-dependencies = [
5858
python = "3.11"
5959
post-install-commands = ["pip install -U -e {root:uri}/../robotframework"]
6060

61+
[envs.rfdevel38]
62+
python = "3.8"
63+
post-install-commands = ["pip install -U -e {root:uri}/../robotframework"]
64+
6165
[envs.devel]
6266
python = "3.8"
6367

Lines changed: 61 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,38 @@
11
import ast
2-
from typing import Any, AsyncIterator, Callable, Dict, Iterator, Optional, Type
2+
from abc import ABC
3+
from collections import defaultdict
4+
5+
from robot.parsing.model.statements import Statement
6+
from typing_extensions import Any, AsyncIterator, Callable, Dict, Iterator, Optional, Type, Union
37

48
__all__ = ["iter_fields", "iter_child_nodes", "AsyncVisitor"]
59

610

11+
def _patch_robot() -> None:
12+
if hasattr(Statement, "_fields"):
13+
Statement._fields = ()
14+
15+
16+
_patch_robot()
17+
18+
719
def iter_fields(node: ast.AST) -> Iterator[Any]:
8-
"""
9-
Yield a tuple of ``(fieldname, value)`` for each field in ``node._fields``
10-
that is present on *node*.
11-
"""
1220
for field in node._fields:
1321
try:
1422
yield field, getattr(node, field)
1523
except AttributeError:
1624
pass
1725

1826

27+
def iter_field_values(node: ast.AST) -> Iterator[Any]:
28+
for field in node._fields:
29+
try:
30+
yield getattr(node, field)
31+
except AttributeError:
32+
pass
33+
34+
1935
def iter_child_nodes(node: ast.AST) -> Iterator[ast.AST]:
20-
"""
21-
Yield all direct child nodes of *node*, that is, all fields that are nodes
22-
and all items of fields that are lists of nodes.
23-
"""
2436
for _name, field in iter_fields(node):
2537
if isinstance(field, ast.AST):
2638
yield field
@@ -46,60 +58,69 @@ async def iter_nodes(node: ast.AST) -> AsyncIterator[ast.AST]:
4658
yield n
4759

4860

49-
class VisitorFinder:
50-
__NOT_SET = object()
61+
class _NotSet:
62+
pass
63+
5164

52-
def __init__(self) -> None:
53-
self.__cache: Dict[Type[Any], Optional[Callable[..., Any]]] = {}
65+
class VisitorFinder(ABC):
66+
__NOT_SET = _NotSet()
67+
__cls_finder_cache__: Dict[Type[Any], Union[Callable[..., Any], None, _NotSet]]
5468

55-
def __find_visitor(self, cls: Type[Any]) -> Optional[Callable[..., Any]]:
56-
if cls is ast.AST:
69+
def __init_subclass__(cls, **kwargs: Any) -> None:
70+
super().__init_subclass__(**kwargs)
71+
cls.__cls_finder_cache__ = defaultdict(lambda: cls.__NOT_SET)
72+
73+
@classmethod
74+
def __find_visitor(cls, node_cls: Type[Any]) -> Optional[Callable[..., Any]]:
75+
if node_cls is ast.AST:
5776
return None
58-
method_name = "visit_" + cls.__name__
59-
if hasattr(self, method_name):
60-
method = getattr(self, method_name)
61-
if callable(method):
62-
return method # type: ignore
63-
for base in cls.__bases__:
64-
method = self._find_visitor(base)
77+
method_name = "visit_" + node_cls.__name__
78+
method = getattr(cls, method_name, None)
79+
if callable(method):
80+
return method # type: ignore[no-any-return]
81+
for base in node_cls.__bases__:
82+
method = cls._find_visitor(base)
6583
if method:
66-
return method # type: ignore
84+
return method
6785
return None
6886

69-
def _find_visitor(self, cls: Type[Any]) -> Optional[Callable[..., Any]]:
70-
r = self.__cache.get(cls, self.__NOT_SET)
71-
if r is self.__NOT_SET:
72-
self.__cache[cls] = r = self.__find_visitor(cls)
73-
return r # type: ignore
87+
@classmethod
88+
def _find_visitor(cls, node_cls: Type[Any]) -> Optional[Callable[..., Any]]:
89+
result = cls.__cls_finder_cache__[node_cls]
90+
if result is cls.__NOT_SET:
91+
result = cls.__cls_finder_cache__[node_cls] = cls.__find_visitor(node_cls)
92+
return result # type: ignore[return-value]
7493

7594

7695
class AsyncVisitor(VisitorFinder):
7796
async def visit(self, node: ast.AST) -> None:
78-
visitor = self._find_visitor(type(node)) or self.generic_visit
79-
await visitor(node)
97+
visitor = self._find_visitor(type(node)) or self.__class__.generic_visit
98+
await visitor(self, node)
8099

81100
async def generic_visit(self, node: ast.AST) -> None:
82-
"""Called if no explicit visitor function exists for a node."""
83-
for _, value in iter_fields(node):
84-
if isinstance(value, list):
101+
for value in iter_field_values(node):
102+
if value is None:
103+
continue
104+
if isinstance(value, ast.AST):
105+
await self.visit(value)
106+
elif isinstance(value, list):
85107
for item in value:
86108
if isinstance(item, ast.AST):
87109
await self.visit(item)
88-
elif isinstance(value, ast.AST):
89-
await self.visit(value)
90110

91111

92112
class Visitor(VisitorFinder):
93113
def visit(self, node: ast.AST) -> None:
94-
visitor = self._find_visitor(type(node)) or self.generic_visit
95-
visitor(node)
114+
visitor = self._find_visitor(type(node)) or self.__class__.generic_visit
115+
visitor(self, node)
96116

97117
def generic_visit(self, node: ast.AST) -> None:
98-
"""Called if no explicit visitor function exists for a node."""
99-
for field, value in iter_fields(node):
118+
for value in iter_field_values(node):
119+
if value is None:
120+
continue
100121
if isinstance(value, list):
101122
for item in value:
102123
if isinstance(item, ast.AST):
103124
self.visit(item)
104-
elif isinstance(value, ast.AST):
125+
else:
105126
self.visit(value)

pyproject.toml

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ fail_under = 40
162162
[tool.ruff]
163163
line-length = 120
164164
target-version = "py38"
165-
extend-exclude = ["bundled/libs"]
165+
extend-exclude = ["bundled/libs", ".hatch"]
166166
ignore = ["E741", "N805", "N999", "RUF012"]
167167
select = [
168168
"E",
@@ -220,15 +220,15 @@ implicit_optional = true
220220
disallow_untyped_decorators = false
221221
disallow_subclassing_any = false
222222
exclude = [
223-
'\.mypy_cache',
224-
'\.venv',
225-
'\.hatch',
226-
'build',
227-
'dist',
228-
'out',
229-
'playground',
230-
'scripts',
231-
'bundled/libs',
223+
'^\.mypy_cache/',
224+
'^\.venv/',
225+
'^\.hatch/',
226+
"^build/",
227+
"^dist/",
228+
"^out/",
229+
"^playground/",
230+
"^scripts/",
231+
"^bundled/libs/",
232232
]
233233
mypy_path = [
234234
"typings",

0 commit comments

Comments
 (0)