1
1
import ast
2
- from typing import Any , AsyncIterator , Callable , Dict , Iterator , Optional , Type
2
+ from abc import ABC
3
+ from collections import defaultdict
4
+
5
+ from robot .parsing .model .statements import Statement
6
+ from typing_extensions import Any , AsyncIterator , Callable , Dict , Iterator , Optional , Type , Union
3
7
4
8
__all__ = ["iter_fields" , "iter_child_nodes" , "AsyncVisitor" ]
5
9
6
10
11
+ def _patch_robot () -> None :
12
+ if hasattr (Statement , "_fields" ):
13
+ Statement ._fields = ()
14
+
15
+
16
+ _patch_robot ()
17
+
18
+
7
19
def iter_fields (node : ast .AST ) -> Iterator [Any ]:
8
- """
9
- Yield a tuple of ``(fieldname, value)`` for each field in ``node._fields``
10
- that is present on *node*.
11
- """
12
20
for field in node ._fields :
13
21
try :
14
22
yield field , getattr (node , field )
15
23
except AttributeError :
16
24
pass
17
25
18
26
27
+ def iter_field_values (node : ast .AST ) -> Iterator [Any ]:
28
+ for field in node ._fields :
29
+ try :
30
+ yield getattr (node , field )
31
+ except AttributeError :
32
+ pass
33
+
34
+
19
35
def iter_child_nodes (node : ast .AST ) -> Iterator [ast .AST ]:
20
- """
21
- Yield all direct child nodes of *node*, that is, all fields that are nodes
22
- and all items of fields that are lists of nodes.
23
- """
24
36
for _name , field in iter_fields (node ):
25
37
if isinstance (field , ast .AST ):
26
38
yield field
@@ -46,60 +58,69 @@ async def iter_nodes(node: ast.AST) -> AsyncIterator[ast.AST]:
46
58
yield n
47
59
48
60
49
- class VisitorFinder :
50
- __NOT_SET = object ()
61
+ class _NotSet :
62
+ pass
63
+
51
64
52
- def __init__ (self ) -> None :
53
- self .__cache : Dict [Type [Any ], Optional [Callable [..., Any ]]] = {}
65
+ class VisitorFinder (ABC ):
66
+ __NOT_SET = _NotSet ()
67
+ __cls_finder_cache__ : Dict [Type [Any ], Union [Callable [..., Any ], None , _NotSet ]]
54
68
55
- def __find_visitor (self , cls : Type [Any ]) -> Optional [Callable [..., Any ]]:
56
- if cls is ast .AST :
69
+ def __init_subclass__ (cls , ** kwargs : Any ) -> None :
70
+ super ().__init_subclass__ (** kwargs )
71
+ cls .__cls_finder_cache__ = defaultdict (lambda : cls .__NOT_SET )
72
+
73
+ @classmethod
74
+ def __find_visitor (cls , node_cls : Type [Any ]) -> Optional [Callable [..., Any ]]:
75
+ if node_cls is ast .AST :
57
76
return None
58
- method_name = "visit_" + cls .__name__
59
- if hasattr (self , method_name ):
60
- method = getattr (self , method_name )
61
- if callable (method ):
62
- return method # type: ignore
63
- for base in cls .__bases__ :
64
- method = self ._find_visitor (base )
77
+ method_name = "visit_" + node_cls .__name__
78
+ method = getattr (cls , method_name , None )
79
+ if callable (method ):
80
+ return method # type: ignore[no-any-return]
81
+ for base in node_cls .__bases__ :
82
+ method = cls ._find_visitor (base )
65
83
if method :
66
- return method # type: ignore
84
+ return method
67
85
return None
68
86
69
- def _find_visitor (self , cls : Type [Any ]) -> Optional [Callable [..., Any ]]:
70
- r = self .__cache .get (cls , self .__NOT_SET )
71
- if r is self .__NOT_SET :
72
- self .__cache [cls ] = r = self .__find_visitor (cls )
73
- return r # type: ignore
87
+ @classmethod
88
+ def _find_visitor (cls , node_cls : Type [Any ]) -> Optional [Callable [..., Any ]]:
89
+ result = cls .__cls_finder_cache__ [node_cls ]
90
+ if result is cls .__NOT_SET :
91
+ result = cls .__cls_finder_cache__ [node_cls ] = cls .__find_visitor (node_cls )
92
+ return result # type: ignore[return-value]
74
93
75
94
76
95
class AsyncVisitor (VisitorFinder ):
77
96
async def visit (self , node : ast .AST ) -> None :
78
- visitor = self ._find_visitor (type (node )) or self .generic_visit
79
- await visitor (node )
97
+ visitor = self ._find_visitor (type (node )) or self .__class__ . generic_visit
98
+ await visitor (self , node )
80
99
81
100
async def generic_visit (self , node : ast .AST ) -> None :
82
- """Called if no explicit visitor function exists for a node."""
83
- for _ , value in iter_fields (node ):
84
- if isinstance (value , list ):
101
+ for value in iter_field_values (node ):
102
+ if value is None :
103
+ continue
104
+ if isinstance (value , ast .AST ):
105
+ await self .visit (value )
106
+ elif isinstance (value , list ):
85
107
for item in value :
86
108
if isinstance (item , ast .AST ):
87
109
await self .visit (item )
88
- elif isinstance (value , ast .AST ):
89
- await self .visit (value )
90
110
91
111
92
112
class Visitor (VisitorFinder ):
93
113
def visit (self , node : ast .AST ) -> None :
94
- visitor = self ._find_visitor (type (node )) or self .generic_visit
95
- visitor (node )
114
+ visitor = self ._find_visitor (type (node )) or self .__class__ . generic_visit
115
+ visitor (self , node )
96
116
97
117
def generic_visit (self , node : ast .AST ) -> None :
98
- """Called if no explicit visitor function exists for a node."""
99
- for field , value in iter_fields (node ):
118
+ for value in iter_field_values (node ):
119
+ if value is None :
120
+ continue
100
121
if isinstance (value , list ):
101
122
for item in value :
102
123
if isinstance (item , ast .AST ):
103
124
self .visit (item )
104
- elif isinstance ( value , ast . AST ) :
125
+ else :
105
126
self .visit (value )
0 commit comments