diff --git a/docs/schema.json b/docs/schema.json index d272198e..4476bd6e 100644 --- a/docs/schema.json +++ b/docs/schema.json @@ -339,41 +339,9 @@ } ], "$defs": { - "name": { - "type": "object", - "properties": { - "source": { - "title": "The annotation as written in the source code.", - "markdownDescription": "https://mkdocstrings.github.io/griffe/reference/griffe/dataclasses/#griffe.expressions.Name", - "type": "string" - }, - "full": { - "title": "The full path of the .", - "markdownDescription": "https://mkdocstrings.github.io/griffe/reference/griffe/dataclasses/#griffe.expressions.Name", - "type": "string" - } - }, - "additionalProperties": false, - "required": [ - "source", - "full" - ] - }, "expression": { - "type": "array", - "items": { - "oneOf": [ - { - "type": "string" - }, - { - "$ref": "#/$defs/name" - }, - { - "$ref": "#/$defs/expression" - } - ] - } + "type": "object", + "additionalProperties": true }, "annotation": { "oneOf": [ @@ -383,9 +351,6 @@ { "type": "string" }, - { - "$ref": "#/$defs/name" - }, { "$ref": "#/$defs/expression" } diff --git a/src/griffe/agents/inspector.py b/src/griffe/agents/inspector.py index 77496d50..3ef11205 100644 --- a/src/griffe/agents/inspector.py +++ b/src/griffe/agents/inspector.py @@ -48,7 +48,7 @@ from pathlib import Path from griffe.docstrings.parsers import Parser - from griffe.expressions import Expression, Name + from griffe.expressions import Expr empty = Signature.empty @@ -405,7 +405,7 @@ def inspect_attribute(self, node: ObjectNode) -> None: """ self.handle_attribute(node) - def handle_attribute(self, node: ObjectNode, annotation: str | Name | Expression | None = None) -> None: + def handle_attribute(self, node: ObjectNode, annotation: str | Expr | None = None) -> None: """Handle an attribute. Parameters: @@ -478,7 +478,7 @@ def _convert_parameter(parameter: SignatureParameter, parent: Module | Class) -> return Parameter(name, annotation=annotation, kind=kind, default=default) -def _convert_object_to_annotation(obj: Any, parent: Module | Class) -> str | Name | Expression | None: +def _convert_object_to_annotation(obj: Any, parent: Module | Class) -> str | Expr | None: # even when *we* import future annotations, # the object from which we get a signature # can come from modules which did *not* import them, diff --git a/src/griffe/agents/nodes/__init__.py b/src/griffe/agents/nodes/__init__.py index 3d1b5642..34bdd579 100644 --- a/src/griffe/agents/nodes/__init__.py +++ b/src/griffe/agents/nodes/__init__.py @@ -15,7 +15,12 @@ ast_siblings, ) from griffe.agents.nodes._docstrings import get_docstring -from griffe.agents.nodes._expressions import ( +from griffe.agents.nodes._imports import relative_to_absolute +from griffe.agents.nodes._names import get_instance_names, get_name, get_names +from griffe.agents.nodes._parameters import get_call_keyword_arguments +from griffe.agents.nodes._runtime import ObjectKind, ObjectNode +from griffe.agents.nodes._values import get_value, safe_get_value +from griffe.expressions import ( get_annotation, get_base_class, get_condition, @@ -25,11 +30,6 @@ safe_get_condition, safe_get_expression, ) -from griffe.agents.nodes._imports import relative_to_absolute -from griffe.agents.nodes._names import get_instance_names, get_name, get_names -from griffe.agents.nodes._parameters import get_call_keyword_arguments -from griffe.agents.nodes._runtime import ObjectKind, ObjectNode -from griffe.agents.nodes._values import get_value, safe_get_value __all__ = [ "ast_children", diff --git a/src/griffe/agents/nodes/_all.py b/src/griffe/agents/nodes/_all.py index 8fd56a15..f95eee3f 100644 --- a/src/griffe/agents/nodes/_all.py +++ b/src/griffe/agents/nodes/_all.py @@ -4,11 +4,10 @@ import ast from contextlib import suppress -from functools import partial from typing import TYPE_CHECKING, Any, Callable from griffe.agents.nodes._values import get_value -from griffe.expressions import Name +from griffe.expressions import ExprName from griffe.logger import LogLevel, get_logger if TYPE_CHECKING: @@ -18,32 +17,32 @@ logger = get_logger(__name__) -def _extract_constant(node: ast.Constant, parent: Module) -> list[str | Name]: +def _extract_constant(node: ast.Constant, parent: Module) -> list[str | ExprName]: return [node.value] -def _extract_name(node: ast.Name, parent: Module) -> list[str | Name]: - return [Name(node.id, partial(parent.resolve, node.id))] +def _extract_name(node: ast.Name, parent: Module) -> list[str | ExprName]: + return [ExprName(node.id, parent)] -def _extract_starred(node: ast.Starred, parent: Module) -> list[str | Name]: +def _extract_starred(node: ast.Starred, parent: Module) -> list[str | ExprName]: return _extract(node.value, parent) -def _extract_sequence(node: ast.List | ast.Set | ast.Tuple, parent: Module) -> list[str | Name]: +def _extract_sequence(node: ast.List | ast.Set | ast.Tuple, parent: Module) -> list[str | ExprName]: sequence = [] for elt in node.elts: sequence.extend(_extract(elt, parent)) return sequence -def _extract_binop(node: ast.BinOp, parent: Module) -> list[str | Name]: +def _extract_binop(node: ast.BinOp, parent: Module) -> list[str | ExprName]: left = _extract(node.left, parent) right = _extract(node.right, parent) return left + right -_node_map: dict[type, Callable[[Any, Module], list[str | Name]]] = { +_node_map: dict[type, Callable[[Any, Module], list[str | ExprName]]] = { ast.Constant: _extract_constant, ast.Name: _extract_name, ast.Starred: _extract_starred, @@ -54,11 +53,11 @@ def _extract_binop(node: ast.BinOp, parent: Module) -> list[str | Name]: } -def _extract(node: ast.AST, parent: Module) -> list[str | Name]: +def _extract(node: ast.AST, parent: Module) -> list[str | ExprName]: return _node_map[type(node)](node, parent) -def get__all__(node: ast.Assign | ast.AugAssign, parent: Module) -> list[str | Name]: +def get__all__(node: ast.Assign | ast.AugAssign, parent: Module) -> list[str | ExprName]: """Get the values declared in `__all__`. Parameters: @@ -77,7 +76,7 @@ def safe_get__all__( node: ast.Assign | ast.AugAssign, parent: Module, log_level: LogLevel = LogLevel.debug, # TODO: set to error when we handle more things -) -> list[str | Name]: +) -> list[str | ExprName]: """Safely (no exception) extract values in `__all__`. Parameters: diff --git a/src/griffe/agents/nodes/_expressions.py b/src/griffe/agents/nodes/_expressions.py deleted file mode 100644 index 1c7be6e4..00000000 --- a/src/griffe/agents/nodes/_expressions.py +++ /dev/null @@ -1,562 +0,0 @@ -"""This module contains utilities for building information from nodes.""" - -from __future__ import annotations - -import ast -import sys -from functools import partial -from typing import TYPE_CHECKING, Any, Callable, Sequence - -from griffe.expressions import Expression, Name -from griffe.logger import LogLevel, get_logger - -if TYPE_CHECKING: - from pathlib import Path - - from griffe.dataclasses import Class, Module - - -logger = get_logger(__name__) - - -def _join(sequence: Sequence, item: str | Name | Expression) -> list: - if not sequence: - return [] - new_sequence = [sequence[0]] - for element in sequence[1:]: - new_sequence.append(item) - new_sequence.append(element) - return new_sequence - - -def _build_add(node: ast.Add, parent: Module | Class, **kwargs: Any) -> str: - return "+" - - -def _build_and(node: ast.And, parent: Module | Class, **kwargs: Any) -> str: - return "and" - - -def _build_arguments(node: ast.arguments, parent: Module | Class, **kwargs: Any) -> str: - return ", ".join(arg.arg for arg in node.args) - - -def _build_attribute(node: ast.Attribute, parent: Module | Class, **kwargs: Any) -> Expression: - left = _build(node.value, parent, **kwargs) - - if isinstance(left, str): - resolver = f"str.{node.attr}" - else: - - def resolver() -> str: # type: ignore[misc] - return f"{left.full}.{node.attr}" - - right = Name(node.attr, resolver, first_attr_name=False) - return Expression(left, ".", right) - - -def _build_binop(node: ast.BinOp, parent: Module | Class, **kwargs: Any) -> Expression: - left = _build(node.left, parent, **kwargs) - right = _build(node.right, parent, **kwargs) - return Expression(left, " ", _build(node.op, parent, **kwargs), " ", right) - - -def _build_bitand(node: ast.BitAnd, parent: Module | Class, **kwargs: Any) -> str: - return "&" - - -def _build_bitor(node: ast.BitOr, parent: Module | Class, **kwargs: Any) -> str: - return "|" - - -def _build_bitxor(node: ast.BitXor, parent: Module | Class, **kwargs: Any) -> str: - return "^" - - -def _build_boolop(node: ast.BoolOp, parent: Module | Class, **kwargs: Any) -> Expression: - return Expression( - *_join([_build(value, parent, **kwargs) for value in node.values], f" {_build(node.op, parent, **kwargs)} "), - ) - - -def _build_call(node: ast.Call, parent: Module | Class, **kwargs: Any) -> Expression: - positional_args = Expression(*_join([_build(arg, parent, **kwargs) for arg in node.args], ", ")) - keyword_args = Expression(*_join([_build(kwarg, parent, **kwargs) for kwarg in node.keywords], ", ")) - args: Expression | str - if positional_args and keyword_args: - args = Expression(positional_args, ", ", keyword_args) - elif positional_args: - args = positional_args - elif keyword_args: - args = keyword_args - else: - args = "" - return Expression(_build(node.func, parent, **kwargs), "(", args, ")") - - -def _build_compare(node: ast.Compare, parent: Module | Class, **kwargs: Any) -> Expression: - left = _build(node.left, parent, **kwargs) - ops = [_build(op, parent, **kwargs) for op in node.ops] - comparators = [_build(comparator, parent, **kwargs) for comparator in node.comparators] - return Expression(left, " ", *_join([Expression(op, " ", comp) for op, comp in zip(ops, comparators)], " ")) - - -def _build_comprehension(node: ast.comprehension, parent: Module | Class, **kwargs: Any) -> Expression: - target = _build(node.target, parent, **kwargs) - iterable = _build(node.iter, parent, **kwargs) - conditions = [_build(condition, parent, **kwargs) for condition in node.ifs] - value = Expression("for ", target, " in ", iterable) - if conditions: - value.extend((" if ", *_join(conditions, " if "))) - if node.is_async: - value.insert(0, "async ") - return value - - -def _build_constant( - node: ast.Constant, - parent: Module | Class, - *, - in_formatted_str: bool = False, - in_joined_str: bool = False, - parse_strings: bool = False, - literal_strings: bool = False, - **kwargs: Any, -) -> str | Name | Expression: - if isinstance(node.value, str): - if in_joined_str and not in_formatted_str: - # We're in a f-string, not in a formatted value, don't keep quotes. - return node.value - if parse_strings and not literal_strings: - # We're in a place where a string could be a type annotation - # (and not in a Literal[...] type annotation). - # We parse the string and build from the resulting nodes again. - # If we fail to parse it (syntax errors), we consider it's a literal string and log a message. - try: - parsed = compile( - node.value, - mode="eval", - filename="", - flags=ast.PyCF_ONLY_AST, - optimize=1, - ) - except SyntaxError: - logger.debug( - f"Tried and failed to parse {node.value!r} as Python code, " - "falling back to using it as a string literal " - "(postponed annotations might help: https://peps.python.org/pep-0563/)", - ) - else: - return _build(parsed.body, parent, **kwargs) # type: ignore[attr-defined] - return {type(...): lambda _: "..."}.get(type(node.value), repr)(node.value) - - -def _build_dict(node: ast.Dict, parent: Module | Class, **kwargs: Any) -> Expression: - pairs = zip(node.keys, node.values) - body = [ - Expression("None" if key is None else _build(key, parent, **kwargs), ": ", _build(value, parent, **kwargs)) - for key, value in pairs - ] - return Expression("{", Expression(*_join(body, ", ")), "}") - - -def _build_dictcomp(node: ast.DictComp, parent: Module | Class, **kwargs: Any) -> Expression: - key = _build(node.key, parent, **kwargs) - value = _build(node.value, parent, **kwargs) - generators = [_build(gen, parent, **kwargs) for gen in node.generators] - return Expression("{", key, ": ", value, Expression(*_join(generators, " ")), "}") - - -def _build_div(node: ast.Div, parent: Module | Class, **kwargs: Any) -> str: - return "/" - - -def _build_eq(node: ast.Eq, parent: Module | Class, **kwargs: Any) -> str: - return "==" - - -def _build_floordiv(node: ast.FloorDiv, parent: Module | Class, **kwargs: Any) -> str: - return "//" - - -def _build_formatted(node: ast.FormattedValue, parent: Module | Class, **kwargs: Any) -> Expression: - return Expression("{", _build(node.value, parent, in_formatted_str=True, **kwargs), "}") - - -def _build_generatorexp(node: ast.GeneratorExp, parent: Module | Class, **kwargs: Any) -> Expression: - element = _build(node.elt, parent, **kwargs) - generators = [_build(gen, parent, **kwargs) for gen in node.generators] - return Expression(element, " ", Expression(*_join(generators, " "))) - - -def _build_gte(node: ast.GtE, parent: Module | Class, **kwargs: Any) -> str: - return ">=" - - -def _build_gt(node: ast.Gt, parent: Module | Class, **kwargs: Any) -> str: - return ">" - - -def _build_ifexp(node: ast.IfExp, parent: Module | Class, **kwargs: Any) -> Expression: - return Expression( - _build(node.body, parent, **kwargs), - " if ", - _build(node.test, parent, **kwargs), - " else ", - _build(node.orelse, parent, **kwargs), - ) - - -def _build_invert(node: ast.Invert, parent: Module | Class, **kwargs: Any) -> str: - return "~" - - -def _build_in(node: ast.In, parent: Module | Class, **kwargs: Any) -> str: - return "in" - - -def _build_is(node: ast.Is, parent: Module | Class, **kwargs: Any) -> str: - return "is" - - -def _build_isnot(node: ast.IsNot, parent: Module | Class, **kwargs: Any) -> str: - return "is not" - - -def _build_joinedstr(node: ast.JoinedStr, parent: Module | Class, **kwargs: Any) -> Expression: - return Expression("f'", *[_build(value, parent, in_joined_str=True) for value in node.values], "'") - - -def _build_keyword(node: ast.keyword, parent: Module | Class, **kwargs: Any) -> Expression: - if node.arg is None: - return Expression("**", _build(node.value, parent, **kwargs)) - return Expression(node.arg, "=", _build(node.value, parent, **kwargs)) - - -def _build_lambda(node: ast.Lambda, parent: Module | Class, **kwargs: Any) -> Expression: - return Expression("lambda ", _build(node.args, parent, **kwargs), ": ", _build(node.body, parent, **kwargs)) - - -def _build_list(node: ast.List, parent: Module | Class, **kwargs: Any) -> Expression: - return Expression("[", *_join([_build(el, parent, **kwargs) for el in node.elts], ", "), "]") - - -def _build_listcomp(node: ast.ListComp, parent: Module | Class, **kwargs: Any) -> Expression: - element = _build(node.elt, parent, **kwargs) - generators = [_build(gen, parent, **kwargs) for gen in node.generators] - return Expression("[", element, *_join(generators, " "), "]") - - -def _build_lshift(node: ast.LShift, parent: Module | Class, **kwargs: Any) -> str: - return "<<" - - -def _build_lte(node: ast.LtE, parent: Module | Class, **kwargs: Any) -> str: - return "<=" - - -def _build_lt(node: ast.Lt, parent: Module | Class, **kwargs: Any) -> str: - return "<" - - -def _build_matmult(node: ast.MatMult, parent: Module | Class, **kwargs: Any) -> str: - return "@" - - -def _build_mod(node: ast.Mod, parent: Module | Class, **kwargs: Any) -> str: - return "%" - - -def _build_mult(node: ast.Mult, parent: Module | Class, **kwargs: Any) -> str: - return "*" - - -def _build_name(node: ast.Name, parent: Module | Class, **kwargs: Any) -> Name: - return Name(node.id, partial(parent.resolve, node.id)) - - -def _build_named_expr(node: ast.NamedExpr, parent: Module | Class, **kwargs: Any) -> Expression: - return Expression("(", _build(node.target, parent, **kwargs), " := ", _build(node.value, parent, **kwargs), ")") - - -def _build_not(node: ast.Not, parent: Module | Class, **kwargs: Any) -> str: - return "not " - - -def _build_noteq(node: ast.NotEq, parent: Module | Class, **kwargs: Any) -> str: - return "!=" - - -def _build_notin(node: ast.NotIn, parent: Module | Class, **kwargs: Any) -> str: - return "not in" - - -def _build_or(node: ast.Or, parent: Module | Class, **kwargs: Any) -> str: - return "or" - - -def _build_pow(node: ast.Pow, parent: Module | Class, **kwargs: Any) -> str: - return "**" - - -def _build_rshift(node: ast.RShift, parent: Module | Class, **kwargs: Any) -> str: - return ">>" - - -def _build_set(node: ast.Set, parent: Module | Class, **kwargs: Any) -> Expression: - return Expression("{", *_join([_build(el, parent, **kwargs) for el in node.elts], ", "), "}") - - -def _build_setcomp(node: ast.SetComp, parent: Module | Class, **kwargs: Any) -> Expression: - element = _build(node.elt, parent, **kwargs) - generators = [_build(gen, parent, **kwargs) for gen in node.generators] - return Expression("{", element, " ", *_join(generators, " "), "}") - - -def _build_slice(node: ast.Slice, parent: Module | Class, **kwargs: Any) -> Expression: - lower = _build(node.lower, parent, **kwargs) if node.lower else "" - upper = _build(node.upper, parent, **kwargs) if node.upper else "" - value = Expression(lower, ":", upper) - if node.step: - value.extend((":", _build(node.step, parent, **kwargs))) - return value - - -def _build_starred(node: ast.Starred, parent: Module | Class, **kwargs: Any) -> Expression: - return Expression("*", _build(node.value, parent, **kwargs)) - - -def _build_sub(node: ast.Sub, parent: Module | Class, **kwargs: Any) -> str: - return "-" - - -def _build_subscript( - node: ast.Subscript, - parent: Module | Class, - *, - parse_strings: bool = False, - literal_strings: bool = False, - in_subscript: bool = False, - **kwargs: Any, -) -> Expression: - left = _build(node.value, parent, **kwargs) - if parse_strings: - if left.full in {"typing.Literal", "typing_extensions.Literal"}: # type: ignore[union-attr] - literal_strings = True - subscript = _build( - node.slice, - parent, - parse_strings=True, - literal_strings=literal_strings, - in_subscript=True, - **kwargs, - ) - else: - subscript = _build(node.slice, parent, in_subscript=True, **kwargs) - return Expression(left, "[", subscript, "]") - - -def _build_tuple( - node: ast.Tuple, - parent: Module | Class, - *, - in_subscript: bool = False, - **kwargs: Any, -) -> Expression: - values = _join([_build(el, parent, **kwargs) for el in node.elts], ", ") - if in_subscript: - return Expression(*values) - return Expression("(", *values, ")") - - -def _build_uadd(node: ast.UAdd, parent: Module | Class, **kwargs: Any) -> str: - return "+" - - -def _build_unaryop(node: ast.UnaryOp, parent: Module | Class, **kwargs: Any) -> Expression: - return Expression(_build(node.op, parent, **kwargs), _build(node.operand, parent, **kwargs)) - - -def _build_usub(node: ast.USub, parent: Module | Class, **kwargs: Any) -> str: - return "-" - - -def _build_yield(node: ast.Yield, parent: Module | Class, **kwargs: Any) -> str | Name | Expression: - if node.value is None: - return repr(None) - return _build(node.value, parent, **kwargs) - - -_node_map: dict[type, Callable[[Any, Module | Class], str | Name | Expression]] = { - ast.Add: _build_add, - ast.And: _build_and, - ast.arguments: _build_arguments, - ast.Attribute: _build_attribute, - ast.BinOp: _build_binop, - ast.BitAnd: _build_bitand, - ast.BitOr: _build_bitor, - ast.BitXor: _build_bitxor, - ast.BoolOp: _build_boolop, - ast.Call: _build_call, - ast.Compare: _build_compare, - ast.comprehension: _build_comprehension, - ast.Constant: _build_constant, - ast.Dict: _build_dict, - ast.DictComp: _build_dictcomp, - ast.Div: _build_div, - ast.Eq: _build_eq, - ast.FloorDiv: _build_floordiv, - ast.FormattedValue: _build_formatted, - ast.GeneratorExp: _build_generatorexp, - ast.Gt: _build_gt, - ast.GtE: _build_gte, - ast.IfExp: _build_ifexp, - ast.In: _build_in, - ast.Invert: _build_invert, - ast.Is: _build_is, - ast.IsNot: _build_isnot, - ast.JoinedStr: _build_joinedstr, - ast.keyword: _build_keyword, - ast.Lambda: _build_lambda, - ast.List: _build_list, - ast.ListComp: _build_listcomp, - ast.LShift: _build_lshift, - ast.Lt: _build_lt, - ast.LtE: _build_lte, - ast.MatMult: _build_matmult, - ast.Mod: _build_mod, - ast.Mult: _build_mult, - ast.Name: _build_name, - ast.NamedExpr: _build_named_expr, - ast.Not: _build_not, - ast.NotEq: _build_noteq, - ast.NotIn: _build_notin, - ast.Or: _build_or, - ast.Pow: _build_pow, - ast.RShift: _build_rshift, - ast.Set: _build_set, - ast.SetComp: _build_setcomp, - ast.Slice: _build_slice, - ast.Starred: _build_starred, - ast.Sub: _build_sub, - ast.Subscript: _build_subscript, - ast.Tuple: _build_tuple, - ast.UAdd: _build_uadd, - ast.UnaryOp: _build_unaryop, - ast.USub: _build_usub, - ast.Yield: _build_yield, -} - -# TODO: remove once Python 3.8 support is dropped -if sys.version_info < (3, 9): - - def _build_extslice(node: ast.ExtSlice, parent: Module | Class, **kwargs: Any) -> Expression: - return Expression(*_join([_build(dim, parent, **kwargs) for dim in node.dims], ",")) - - def _build_index(node: ast.Index, parent: Module | Class, **kwargs: Any) -> str | Name | Expression: - return _build(node.value, parent, **kwargs) - - _node_map[ast.ExtSlice] = _build_extslice - _node_map[ast.Index] = _build_index - - -def _build(node: ast.AST, parent: Module | Class, **kwargs: Any) -> str | Name | Expression: - return _node_map[type(node)](node, parent, **kwargs) - - -def get_expression( - node: ast.AST | None, - parent: Module | Class, - *, - parse_strings: bool | None = None, -) -> str | Name | Expression | None: - """Build an expression from an AST. - - Parameters: - node: The annotation node. - parent: The parent used to resolve the name. - parse_strings: Whether to try and parse strings as type annotations. - - Returns: - A string or resovable name or expression. - """ - if node is None: - return None - if parse_strings is None: - try: - module = parent.module - except ValueError: - parse_strings = False - else: - parse_strings = not module.imports_future_annotations - return _build(node, parent, parse_strings=parse_strings) - - -def safe_get_expression( - node: ast.AST | None, - parent: Module | Class, - *, - parse_strings: bool | None = None, - log_level: LogLevel | None = LogLevel.error, - msg_format: str = "{path}:{lineno}: Failed to get expression from {node_class}: {error}", -) -> str | Name | Expression | None: - """Safely (no exception) build a resolvable annotation. - - Parameters: - node: The annotation node. - parent: The parent used to resolve the name. - parse_strings: Whether to try and parse strings as type annotations. - log_level: Log level to use to log a message. None to disable logging. - msg_format: A format string for the log message. Available placeholders: - path, lineno, node, error. - - Returns: - A string or resovable name or expression. - """ - try: - return get_expression(node, parent, parse_strings=parse_strings) - except Exception as error: # noqa: BLE001 - if log_level is None: - return None - node_class = node.__class__.__name__ - try: - path: Path | str = parent.relative_filepath - except ValueError: - path = "" - lineno = node.lineno # type: ignore[union-attr] - message = msg_format.format(path=path, lineno=lineno, node_class=node_class, error=error) - getattr(logger, log_level.value)(message) - return None - - -_msg_format = "{path}:{lineno}: Failed to get %s expression from {node_class}: {error}" -get_annotation = partial(get_expression, parse_strings=None) -safe_get_annotation = partial( - safe_get_expression, - parse_strings=None, - msg_format=_msg_format % "annotation", -) -get_base_class = partial(get_expression, parse_strings=False) -safe_get_base_class = partial( - safe_get_expression, - parse_strings=False, - msg_format=_msg_format % "base class", -) -get_condition = partial(get_expression, parse_strings=False) -safe_get_condition = partial( - safe_get_expression, - parse_strings=False, - msg_format=_msg_format % "condition", -) - - -__all__ = [ - "get_annotation", - "get_base_class", - "get_condition", - "get_expression", - "safe_get_annotation", - "safe_get_base_class", - "safe_get_condition", - "safe_get_expression", -] diff --git a/src/griffe/agents/nodes/_parameters.py b/src/griffe/agents/nodes/_parameters.py index 2ed11db2..69f0e99f 100644 --- a/src/griffe/agents/nodes/_parameters.py +++ b/src/griffe/agents/nodes/_parameters.py @@ -4,11 +4,11 @@ from typing import TYPE_CHECKING, Any -from griffe.agents.nodes._expressions import safe_get_expression +from griffe.expressions import safe_get_expression from griffe.logger import get_logger if TYPE_CHECKING: - from ast import Call + import ast from griffe.dataclasses import Class, Module @@ -16,7 +16,7 @@ logger = get_logger(__name__) -def get_call_keyword_arguments(node: Call, parent: Module | Class) -> dict[str, Any]: +def get_call_keyword_arguments(node: ast.Call, parent: Module | Class) -> dict[str, Any]: """Get the list of keyword argument names and values from a Call node. Parameters: diff --git a/src/griffe/agents/nodes/_runtime.py b/src/griffe/agents/nodes/_runtime.py index 7e0da3ba..50e4d93b 100644 --- a/src/griffe/agents/nodes/_runtime.py +++ b/src/griffe/agents/nodes/_runtime.py @@ -2,12 +2,12 @@ from __future__ import annotations -import enum import inspect from functools import cached_property from inspect import getmodule from typing import Any, ClassVar, Sequence +from griffe.enumerations import ObjectKind from griffe.logger import get_logger logger = get_logger(__name__) @@ -20,40 +20,6 @@ } -class ObjectKind(enum.Enum): - """Enumeration for the different kinds of objects.""" - - MODULE: str = "module" - """Modules.""" - CLASS: str = "class" - """Classes.""" - STATICMETHOD: str = "staticmethod" - """Static methods.""" - CLASSMETHOD: str = "classmethod" - """Class methods.""" - METHOD_DESCRIPTOR: str = "method_descriptor" - """Method descriptors.""" - METHOD: str = "method" - """Methods.""" - BUILTIN_METHOD: str = "builtin_method" - """Built-in ethods.""" - COROUTINE: str = "coroutine" - """Coroutines""" - FUNCTION: str = "function" - """Functions.""" - BUILTIN_FUNCTION: str = "builtin_function" - """Built-in functions.""" - CACHED_PROPERTY: str = "cached_property" - """Cached properties.""" - PROPERTY: str = "property" - """Properties.""" - ATTRIBUTE: str = "attribute" - """Attributes.""" - - def __str__(self) -> str: - return self.value - - class ObjectNode: """Helper class to represent an object tree. diff --git a/src/griffe/agents/visitor.py b/src/griffe/agents/visitor.py index c97eabee..bf52730b 100644 --- a/src/griffe/agents/visitor.py +++ b/src/griffe/agents/visitor.py @@ -25,7 +25,6 @@ safe_get_annotation, safe_get_base_class, safe_get_condition, - safe_get_expression, ) from griffe.collections import LinesCollection, ModulesCollection from griffe.dataclasses import ( @@ -42,14 +41,13 @@ Parameters, ) from griffe.exceptions import AliasResolutionError, CyclicAliasError, LastNodeError -from griffe.expressions import Expression +from griffe.expressions import Expr, safe_get_expression from griffe.extensions import Extensions if TYPE_CHECKING: from pathlib import Path from griffe.docstrings.parsers import Parser - from griffe.expressions import Name builtin_decorators = { @@ -369,7 +367,7 @@ def handle_function(self, node: ast.AsyncFunctionDef | ast.FunctionDef, labels: # handle parameters parameters = Parameters() - annotation: str | Name | Expression | None + annotation: str | Expr | None posonlyargs = node.args.posonlyargs @@ -552,7 +550,7 @@ def visit_importfrom(self, node: ast.ImportFrom) -> None: def handle_attribute( self, node: ast.Assign | ast.AnnAssign, - annotation: str | Name | Expression | None = None, + annotation: str | Expr | None = None, ) -> None: """Handle an attribute (assignment) node. @@ -577,9 +575,9 @@ def handle_attribute( except KeyError: # unsupported nodes, like subscript return - if isinstance(annotation, Expression) and annotation.is_classvar: + if isinstance(annotation, Expr) and annotation.is_classvar: # explicit classvar: class attribute only - annotation = annotation[2] + annotation = annotation.slice # type: ignore[attr-defined] labels.add("class-attribute") elif node.value: # attribute assigned at class-level: available in instances as well diff --git a/src/griffe/dataclasses.py b/src/griffe/dataclasses.py index c75d746e..5a99aa70 100644 --- a/src/griffe/dataclasses.py +++ b/src/griffe/dataclasses.py @@ -6,7 +6,6 @@ from __future__ import annotations -import enum import inspect from collections import defaultdict from contextlib import suppress @@ -16,39 +15,22 @@ from griffe.c3linear import c3linear_merge from griffe.docstrings.parsers import Parser, parse +from griffe.enumerations import Kind, ParameterKind from griffe.exceptions import AliasResolutionError, BuiltinModuleError, CyclicAliasError, NameResolutionError -from griffe.expressions import Name +from griffe.expressions import ExprCall, ExprName from griffe.logger import get_logger from griffe.mixins import GetMembersMixin, ObjectAliasMixin, SerializationMixin, SetMembersMixin if TYPE_CHECKING: from griffe.collections import LinesCollection, ModulesCollection from griffe.docstrings.dataclasses import DocstringSection - from griffe.expressions import Expression + from griffe.expressions import Expr from functools import cached_property logger = get_logger(__name__) -class ParameterKind(enum.Enum): - """Enumeration of the different parameter kinds. - - Attributes: - positional_only: Positional-only parameter. - positional_or_keyword: Positional or keyword parameter. - var_positional: Variadic positional parameter. - keyword_only: Keyword-only parameter. - var_keyword: Variadic keyword parameter. - """ - - positional_only: str = "positional-only" - positional_or_keyword: str = "positional or keyword" - var_positional: str = "variadic positional" - keyword_only: str = "keyword-only" - var_keyword: str = "variadic keyword" - - class Decorator: """This class represents decorators. @@ -57,7 +39,7 @@ class Decorator: endlineno: The ending line number. """ - def __init__(self, value: str | Name | Expression, *, lineno: int | None, endlineno: int | None) -> None: + def __init__(self, value: str | Expr, *, lineno: int | None, endlineno: int | None) -> None: """Initialize the decorator. Parameters: @@ -65,15 +47,15 @@ def __init__(self, value: str | Name | Expression, *, lineno: int | None, endlin lineno: The starting line number. endlineno: The ending line number. """ - self.value: str | Name | Expression = value + self.value: str | Expr = value self.lineno: int | None = lineno self.endlineno: int | None = endlineno @property def callable_path(self) -> str: """The path of the callable used as decorator.""" - value = self.value if isinstance(self.value, str) else self.value.full - return value.split("(", 1)[0] + value = self.value.function if isinstance(self.value, ExprCall) else self.value + return value if isinstance(value, str) else value.canonical_path def as_dict(self, **kwargs: Any) -> dict[str, Any]: # noqa: ARG002 """Return this decorator's data as a dictionary. @@ -197,9 +179,9 @@ def __init__( self, name: str, *, - annotation: str | Name | Expression | None = None, + annotation: str | Expr | None = None, kind: ParameterKind | None = None, - default: str | Name | Expression | None = None, + default: str | Expr | None = None, ) -> None: """Initialize the parameter. @@ -210,9 +192,9 @@ def __init__( default: The parameter default, if any. """ self.name: str = name - self.annotation: str | Name | Expression | None = annotation + self.annotation: str | Expr | None = annotation self.kind: ParameterKind | None = kind - self.default: str | Name | Expression | None = default + self.default: str | Expr | None = default def __str__(self) -> str: param = f"{self.name}: {self.annotation} = {self.default}" @@ -293,23 +275,6 @@ def add(self, parameter: Parameter) -> None: raise ValueError(f"parameter {parameter.name} already present") -class Kind(enum.Enum): - """Enumeration of the different objects kinds. - - Attributes: - MODULE: The module kind. - CLASS: The class kind. - FUNCTION: The function kind. - ATTRIBUTE: The attribute kind. - """ - - MODULE: str = "module" - CLASS: str = "class" - FUNCTION: str = "function" - ATTRIBUTE: str = "attribute" - ALIAS: str = "alias" - - class Object(GetMembersMixin, SetMembersMixin, ObjectAliasMixin, SerializationMixin): """An abstract class representing a Python object. @@ -361,7 +326,7 @@ def __init__( self.members: dict[str, Object | Alias] = {} self.labels: set[str] = set() self.imports: dict[str, str] = {} - self.exports: set[str] | list[str | Name] | None = None + self.exports: set[str] | list[str | ExprName] | None = None self.aliases: dict[str, Alias] = {} self.runtime: bool = runtime self.extra: dict[str, dict[str, Any]] = defaultdict(dict) @@ -982,7 +947,7 @@ def imports(self) -> dict[str, str]: # noqa: D102 return self.final_target.imports @property - def exports(self) -> set[str] | list[str | Name] | None: # noqa: D102 + def exports(self) -> set[str] | list[str | ExprName] | None: # noqa: D102 return self.final_target.exports @property @@ -1092,7 +1057,7 @@ def _filepath(self) -> Path | list[Path] | None: return cast(Module, self.target)._filepath @property - def bases(self) -> list[Name | Expression | str]: # noqa: D102 + def bases(self) -> list[Expr | str]: # noqa: D102 return cast(Class, self.target).bases @property @@ -1112,11 +1077,11 @@ def parameters(self) -> Parameters: # noqa: D102 return cast(Function, self.target).parameters @property - def returns(self) -> str | Name | Expression | None: # noqa: D102 + def returns(self) -> str | Expr | None: # noqa: D102 return cast(Function, self.target).returns @returns.setter - def returns(self, returns: str | Name | Expression | None) -> None: + def returns(self, returns: str | Expr | None) -> None: cast(Function, self.target).returns = returns @property @@ -1128,15 +1093,15 @@ def deleter(self) -> Function | None: # noqa: D102 return cast(Function, self.target).deleter @property - def value(self) -> str | Name | Expression | None: # noqa: D102 + def value(self) -> str | Expr | None: # noqa: D102 return cast(Attribute, self.target).value @property - def annotation(self) -> str | Name | Expression | None: # noqa: D102 + def annotation(self) -> str | Expr | None: # noqa: D102 return cast(Attribute, self.target).annotation @annotation.setter - def annotation(self, annotation: str | Name | Expression | None) -> None: + def annotation(self, annotation: str | Expr | None) -> None: cast(Attribute, self.target).annotation = annotation @property @@ -1410,7 +1375,7 @@ class Class(Object): def __init__( self, *args: Any, - bases: Sequence[Name | Expression | str] | None = None, + bases: Sequence[Expr | str] | None = None, decorators: list[Decorator] | None = None, **kwargs: Any, ) -> None: @@ -1423,7 +1388,7 @@ def __init__( **kwargs: See [`griffe.dataclasses.Object`][]. """ super().__init__(*args, **kwargs) - self.bases: list[Name | Expression | str] = list(bases) if bases else [] + self.bases: list[Expr | str] = list(bases) if bases else [] self.decorators: list[Decorator] = decorators or [] self.overloads: dict[str, list[Function]] = defaultdict(list) @@ -1455,12 +1420,7 @@ def resolved_bases(self) -> list[Object]: """ resolved_bases = [] for base in self.bases: - if isinstance(base, str): - base_path = base - elif isinstance(base, Name): - base_path = base.full - else: - base_path = base.without_subscript.full + base_path = base if isinstance(base, str) else base.canonical_path try: resolved_base = self.modules_collection[base_path] if resolved_base.is_alias: @@ -1510,7 +1470,7 @@ def __init__( self, *args: Any, parameters: Parameters | None = None, - returns: str | Name | Expression | None = None, + returns: str | Expr | None = None, decorators: list[Decorator] | None = None, **kwargs: Any, ) -> None: @@ -1525,14 +1485,14 @@ def __init__( """ super().__init__(*args, **kwargs) self.parameters: Parameters = parameters or Parameters() - self.returns: str | Name | Expression | None = returns + self.returns: str | Expr | None = returns self.decorators: list[Decorator] = decorators or [] self.setter: Function | None = None self.deleter: Function | None = None self.overloads: list[Function] | None = None @property - def annotation(self) -> str | Name | Expression | None: + def annotation(self) -> str | Expr | None: """Return the return annotation. Returns: @@ -1564,8 +1524,8 @@ class Attribute(Object): def __init__( self, *args: Any, - value: str | Name | Expression | None = None, - annotation: str | Name | Expression | None = None, + value: str | Expr | None = None, + annotation: str | Expr | None = None, **kwargs: Any, ) -> None: """Initialize the function. @@ -1577,8 +1537,8 @@ def __init__( **kwargs: See [`griffe.dataclasses.Object`][]. """ super().__init__(*args, **kwargs) - self.value: str | Name | Expression | None = value - self.annotation: str | Name | Expression | None = annotation + self.value: str | Expr | None = value + self.annotation: str | Expr | None = annotation def as_dict(self, **kwargs: Any) -> dict[str, Any]: """Return this function's data as a dictionary. diff --git a/src/griffe/diff.py b/src/griffe/diff.py index c8610a71..bcea8bac 100644 --- a/src/griffe/diff.py +++ b/src/griffe/diff.py @@ -3,13 +3,13 @@ from __future__ import annotations import contextlib -import enum from pathlib import Path from typing import Any, Iterable, Iterator from colorama import Fore, Style from griffe.dataclasses import Alias, Attribute, Class, Function, Object, ParameterKind +from griffe.enumerations import BreakageKind, ExplanationStyle from griffe.exceptions import AliasResolutionError from griffe.git import WORKTREE_PREFIX from griffe.logger import get_logger @@ -22,30 +22,6 @@ logger = get_logger(__name__) -class ExplanationStyle(enum.Enum): - """An enumeration of the possible styles for explanations.""" - - ONE_LINE: str = "oneline" - VERBOSE: str = "verbose" - - -class BreakageKind(enum.Enum): - """An enumeration of the possible breakages.""" - - PARAMETER_MOVED: str = "Positional parameter was moved" - PARAMETER_REMOVED: str = "Parameter was removed" - PARAMETER_CHANGED_KIND: str = "Parameter kind was changed" - PARAMETER_CHANGED_DEFAULT: str = "Parameter default was changed" - PARAMETER_CHANGED_REQUIRED: str = "Parameter is now required" - PARAMETER_ADDED_REQUIRED: str = "Parameter was added as required" - RETURN_CHANGED_TYPE: str = "Return types are incompatible" - OBJECT_REMOVED: str = "Public object was removed" - OBJECT_CHANGED_KIND: str = "Public object points to a different kind of object" - ATTRIBUTE_CHANGED_TYPE: str = "Attribute types are incompatible" - ATTRIBUTE_CHANGED_VALUE: str = "Attribute value was changed" - CLASS_REMOVED_BASE: str = "Base class was removed" - - class Breakage: """Breakages can explain what broke from a version to another.""" @@ -336,10 +312,10 @@ class ClassRemovedBaseBreakage(Breakage): kind: BreakageKind = BreakageKind.CLASS_REMOVED_BASE def _format_old_value(self) -> str: - return "[" + ", ".join(base.full for base in self.old_value) + "]" + return "[" + ", ".join(base.canonical_path for base in self.old_value) + "]" def _format_new_value(self) -> str: - return "[" + ", ".join(base.full for base in self.new_value) + "]" + return "[" + ", ".join(base.canonical_path for base in self.new_value) + "]" # TODO: decorators! diff --git a/src/griffe/docstrings/dataclasses.py b/src/griffe/docstrings/dataclasses.py index 490b1f16..29070444 100644 --- a/src/griffe/docstrings/dataclasses.py +++ b/src/griffe/docstrings/dataclasses.py @@ -2,13 +2,14 @@ from __future__ import annotations -import enum from typing import TYPE_CHECKING +from griffe.enumerations import DocstringSectionKind + if TYPE_CHECKING: from typing import Any, Literal - from griffe.dataclasses import Expression, Name + from griffe.expressions import Expr # Elements ----------------------------------------------- @@ -20,7 +21,7 @@ class DocstringElement: description: The element description. """ - def __init__(self, *, description: str, annotation: str | Name | Expression | None = None) -> None: + def __init__(self, *, description: str, annotation: str | Expr | None = None) -> None: """Initialize the element. Parameters: @@ -28,7 +29,7 @@ def __init__(self, *, description: str, annotation: str | Name | Expression | No description: The element description. """ self.description: str = description - self.annotation: str | Name | Expression | None = annotation + self.annotation: str | Expr | None = annotation def as_dict(self, **kwargs: Any) -> dict[str, Any]: # noqa: ARG002 """Return this element's data as a dictionary. @@ -58,7 +59,7 @@ def __init__( name: str, *, description: str, - annotation: str | Name | Expression | None = None, + annotation: str | Expr | None = None, value: str | None = None, ) -> None: """Initialize the element. @@ -92,7 +93,7 @@ class DocstringAdmonition(DocstringElement): """This class represents an admonition.""" @property - def kind(self) -> str | Name | Expression | None: + def kind(self) -> str | Expr | None: """Return the kind of this admonition. Returns: @@ -101,7 +102,7 @@ def kind(self) -> str | Name | Expression | None: return self.annotation @kind.setter - def kind(self, value: str | Name | Expression) -> None: + def kind(self, value: str | Expr) -> None: self.annotation = value @property @@ -177,23 +178,6 @@ class DocstringAttribute(DocstringNamedElement): # Sections ----------------------------------------------- -class DocstringSectionKind(enum.Enum): - """The possible section kinds.""" - - text = "text" - parameters = "parameters" - other_parameters = "other parameters" - raises = "raises" - warns = "warns" - returns = "returns" - yields = "yields" - receives = "receives" - examples = "examples" - attributes = "attributes" - deprecated = "deprecated" - admonition = "admonition" - - class DocstringSection: """This class represents a docstring section.""" diff --git a/src/griffe/docstrings/google.py b/src/griffe/docstrings/google.py index df3c8019..f2958077 100644 --- a/src/griffe/docstrings/google.py +++ b/src/griffe/docstrings/google.py @@ -30,12 +30,13 @@ DocstringYield, ) from griffe.docstrings.utils import parse_annotation, warning -from griffe.expressions import Expression, Name +from griffe.expressions import ExprName if TYPE_CHECKING: from typing import Any, Literal, Pattern from griffe.dataclasses import Docstring + from griffe.expressions import Expr _warn = warning(__name__) @@ -172,7 +173,7 @@ def _read_parameters( warn_unknown_params: bool = True, ) -> tuple[list[DocstringParameter], int]: parameters = [] - annotation: str | Name | Expression | None + annotation: str | Expr | None block, new_offset = _read_block_items(docstring, offset=offset) @@ -265,7 +266,7 @@ def _read_attributes_section( attributes = [] block, new_offset = _read_block_items(docstring, offset=offset) - annotation: str | Name | Expression | None = None + annotation: str | Expr | None = None for line_number, attr_lines in block: try: name_with_type, description = attr_lines[0].split(":", 1) @@ -305,7 +306,7 @@ def _read_raises_section( exceptions = [] block, new_offset = _read_block_items(docstring, offset=offset) - annotation: str | Name | Expression + annotation: str | Expr for line_number, exception_lines in block: try: annotation, description = exception_lines[0].split(":", 1) @@ -381,18 +382,18 @@ def _read_returns_section( raise ValueError if len(block) > 1: if annotation.is_tuple: - annotation = annotation.tuple_item(index) + annotation = annotation.slice.elements[index] else: if annotation.is_iterator: - return_item = annotation.iterator_item() + return_item = annotation.slice elif annotation.is_generator: - _, _, return_item = annotation.generator_items() + return_item = annotation.slice.elements[2] else: raise ValueError - if isinstance(return_item, Name): + if isinstance(return_item, ExprName): annotation = return_item elif return_item.is_tuple: - annotation = return_item.tuple_item(index) + annotation = return_item.slice.elements[index] else: annotation = return_item @@ -435,15 +436,15 @@ def _read_yields_section( with suppress(AttributeError, KeyError, ValueError): annotation = docstring.parent.returns # type: ignore[union-attr] if annotation.is_iterator: - yield_item = annotation.iterator_item() + yield_item = annotation.slice elif annotation.is_generator: - yield_item, _, _ = annotation.generator_items() + yield_item = annotation.slice.elements[0] else: raise ValueError - if isinstance(yield_item, Name): + if isinstance(yield_item, ExprName): annotation = yield_item elif yield_item.is_tuple: - annotation = yield_item.tuple_item(index) + annotation = yield_item.slice.elements[index] else: annotation = yield_item @@ -486,11 +487,11 @@ def _read_receives_section( with suppress(AttributeError, KeyError): annotation = docstring.parent.returns # type: ignore[union-attr] if annotation.is_generator: - _, receives_item, _ = annotation.generator_items() - if isinstance(receives_item, Name): + receives_item = annotation.slice.elements[1] + if isinstance(receives_item, ExprName): annotation = receives_item elif receives_item.is_tuple: - annotation = receives_item.tuple_item(index) + annotation = receives_item.slice.elements[index] else: annotation = receives_item diff --git a/src/griffe/docstrings/numpy.py b/src/griffe/docstrings/numpy.py index 5486608d..ddbd2bd4 100644 --- a/src/griffe/docstrings/numpy.py +++ b/src/griffe/docstrings/numpy.py @@ -66,13 +66,14 @@ DocstringYield, ) from griffe.docstrings.utils import parse_annotation, warning -from griffe.expressions import Expression, Name +from griffe.expressions import ExprName from griffe.logger import LogLevel if TYPE_CHECKING: from typing import Any, Literal, Pattern from griffe.dataclasses import Docstring + from griffe.expressions import Expr _warn = warning(__name__) @@ -245,7 +246,7 @@ def _read_parameters( **options: Any, ) -> tuple[list[DocstringParameter], int]: parameters = [] - annotation: str | Name | Expression | None + annotation: str | Expr | None items, new_offset = _read_block_items(docstring, offset=offset, **options) @@ -396,18 +397,18 @@ def _read_returns_section( raise ValueError if len(items) > 1: if annotation.is_tuple: - annotation = annotation.tuple_item(index) + annotation = annotation.slice.elements[index] else: if annotation.is_iterator: - return_item = annotation.iterator_item() + return_item = annotation.slice elif annotation.is_generator: - _, _, return_item = annotation.generator_items() + return_item = annotation.slice.elements[2] else: raise ValueError - if isinstance(return_item, Name): + if isinstance(return_item, ExprName): annotation = return_item elif return_item.is_tuple: - annotation = return_item.tuple_item(index) + annotation = return_item.slice.elements[index] else: annotation = return_item else: @@ -447,15 +448,15 @@ def _read_yields_section( with suppress(AttributeError, KeyError, ValueError): annotation = docstring.parent.returns # type: ignore[union-attr] if annotation.is_iterator: - yield_item = annotation.iterator_item() + yield_item = annotation.slice elif annotation.is_generator: - yield_item, _, _ = annotation.generator_items() + yield_item = annotation.slice.elements[0] else: raise ValueError - if isinstance(yield_item, Name): + if isinstance(yield_item, ExprName): annotation = yield_item elif yield_item.is_tuple: - annotation = yield_item.tuple_item(index) + annotation = yield_item.slice.elements[index] else: annotation = yield_item else: @@ -495,11 +496,11 @@ def _read_receives_section( with suppress(AttributeError, KeyError): annotation = docstring.parent.returns # type: ignore[union-attr] if annotation.is_generator: - _, receives_item, _ = annotation.generator_items() - if isinstance(receives_item, Name): + receives_item = annotation.slice.elements[1] + if isinstance(receives_item, ExprName): annotation = receives_item elif receives_item.is_tuple: - annotation = receives_item.tuple_item(index) + annotation = receives_item.slice.elements[index] else: annotation = receives_item else: @@ -569,7 +570,7 @@ def _read_attributes_section( _warn(docstring, new_offset, f"Empty attributes section at line {offset}") return None, new_offset - annotation: str | Name | Expression | None + annotation: str | Expr | None attributes = [] for item in items: name_type = item[0] diff --git a/src/griffe/docstrings/parsers.py b/src/griffe/docstrings/parsers.py index e8f1fb35..d55db9b9 100644 --- a/src/griffe/docstrings/parsers.py +++ b/src/griffe/docstrings/parsers.py @@ -2,33 +2,17 @@ from __future__ import annotations -import enum from typing import TYPE_CHECKING, Any from griffe.docstrings.dataclasses import DocstringSection, DocstringSectionText from griffe.docstrings.google import parse as parse_google from griffe.docstrings.numpy import parse as parse_numpy from griffe.docstrings.sphinx import parse as parse_sphinx +from griffe.enumerations import Parser if TYPE_CHECKING: from griffe.dataclasses import Docstring - -# TODO: assert common denominator / feature parity in all parsers -# - named return, yield, receive -# - exploding return tuple -# - picking yield and receive parts in generator -# - exploding tuple of generator yield part -# - sections titles -# - resolving annotations -class Parser(enum.Enum): - """Enumeration for the different docstring parsers.""" - - google = "google" - sphinx = "sphinx" - numpy = "numpy" - - parsers = { Parser.google: parse_google, Parser.sphinx: parse_sphinx, diff --git a/src/griffe/docstrings/sphinx.py b/src/griffe/docstrings/sphinx.py index 623f4cdb..0c384795 100644 --- a/src/griffe/docstrings/sphinx.py +++ b/src/griffe/docstrings/sphinx.py @@ -27,7 +27,7 @@ if TYPE_CHECKING: from griffe.dataclasses import Docstring - from griffe.expressions import Expression, Name + from griffe.expressions import Expr _warn = warning(__name__) @@ -166,7 +166,7 @@ def _determine_param_annotation( # - "type" directive type # - signature annotation # - none - annotation: str | Name | Expression | None = None + annotation: str | Expr | None = None parsed_param_type = parsed_values.param_types.get(name) if parsed_param_type is not None: @@ -220,7 +220,7 @@ def _read_attribute(docstring: Docstring, offset: int, parsed_values: ParsedValu _warn(docstring, 0, f"Failed to parse field directive from '{parsed_directive.line}'") return parsed_directive.next_index - annotation: str | Name | Expression | None = None + annotation: str | Expr | None = None # Annotation precedence: # - "vartype" directive type @@ -291,7 +291,7 @@ def _read_return(docstring: Docstring, offset: int, parsed_values: ParsedValues) # - "rtype" directive type # - signature annotation # - None - annotation: str | Name | Expression | None + annotation: str | Expr | None if parsed_values.return_type is not None: annotation = parsed_values.return_type else: diff --git a/src/griffe/docstrings/utils.py b/src/griffe/docstrings/utils.py index 3e2b9de2..b17ed220 100644 --- a/src/griffe/docstrings/utils.py +++ b/src/griffe/docstrings/utils.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: from griffe.dataclasses import Docstring - from griffe.expressions import Expression, Name + from griffe.expressions import Expr def warning(name: str) -> Callable[[Docstring, int, str], None]: @@ -46,7 +46,7 @@ def parse_annotation( annotation: str, docstring: Docstring, log_level: LogLevel = LogLevel.error, -) -> str | Name | Expression: +) -> str | Expr: """Parse a string into a true name or expression that can be resolved later. Parameters: diff --git a/src/griffe/encoders.py b/src/griffe/encoders.py index a6439700..1efed11f 100644 --- a/src/griffe/encoders.py +++ b/src/griffe/encoders.py @@ -11,6 +11,7 @@ from pathlib import Path, PosixPath, WindowsPath from typing import TYPE_CHECKING, Any, Callable +from griffe import expressions from griffe.dataclasses import ( Alias, Attribute, @@ -26,7 +27,6 @@ Parameters, ) from griffe.docstrings.dataclasses import DocstringSectionKind -from griffe.expressions import Expression, Name if TYPE_CHECKING: from enum import Enum @@ -116,32 +116,64 @@ def _load_decorators(obj_dict: dict) -> list[Decorator]: return [Decorator(**dec) for dec in obj_dict.get("decorators", [])] -_annotation_loader_map = { - str: lambda _: _, - dict: lambda dct: Name(dct["source"], dct["full"]), - list: lambda lst: Expression(*[_load_annotation(_) for _ in lst]), -} - - -def _load_annotation(annotation: str | dict | list) -> str | Name | Expression: - if annotation is None: - return None - return _annotation_loader_map[type(annotation)](annotation) +def _load_expression(expression: dict) -> expressions.Expr: + cls = getattr(expressions, expression.pop("cls")) + expr = cls(**expression) + if cls is expressions.ExprAttribute: + previous = None + for value in expr.values: + if previous is not None: + value.parent = previous + if isinstance(value, expressions.ExprName): + previous = value + return expr def _load_parameter(obj_dict: dict[str, Any]) -> Parameter: return Parameter( obj_dict["name"], - annotation=_load_annotation(obj_dict["annotation"]), + annotation=obj_dict["annotation"], kind=ParameterKind(obj_dict["kind"]), default=obj_dict["default"], ) +def _attach_parent_to_expr(expr: expressions.Expr | str | None, parent: Module | Class) -> None: + if not isinstance(expr, expressions.Expr): + return + for elem in expr: + if isinstance(elem, expressions.ExprName): + elem.parent = parent + elif isinstance(elem, expressions.ExprAttribute) and isinstance(elem.first, expressions.ExprName): + elem.first.parent = parent + + +def _attach_parent_to_exprs(obj: Class | Function | Attribute, parent: Module | Class) -> None: + if isinstance(obj, Class): + if obj.docstring: + _attach_parent_to_expr(obj.docstring.value, parent) + for decorator in obj.decorators: + _attach_parent_to_expr(decorator.value, parent) + elif isinstance(obj, Function): + if obj.docstring: + _attach_parent_to_expr(obj.docstring.value, parent) + for decorator in obj.decorators: + _attach_parent_to_expr(decorator.value, parent) + for param in obj.parameters: + _attach_parent_to_expr(param.annotation, parent) + _attach_parent_to_expr(param.default, parent) + _attach_parent_to_expr(obj.returns, parent) + elif isinstance(obj, Attribute): + if obj.docstring: + _attach_parent_to_expr(obj.docstring.value, parent) + _attach_parent_to_expr(obj.value, parent) + + def _load_module(obj_dict: dict[str, Any]) -> Module: module = Module(name=obj_dict["name"], filepath=Path(obj_dict["filepath"]), docstring=_load_docstring(obj_dict)) for module_member in obj_dict.get("members", []): module.set_member(module_member.name, module_member) + _attach_parent_to_exprs(module_member, module) module.labels |= set(obj_dict.get("labels", ())) return module @@ -153,11 +185,13 @@ def _load_class(obj_dict: dict[str, Any]) -> Class: endlineno=obj_dict.get("endlineno", None), docstring=_load_docstring(obj_dict), decorators=_load_decorators(obj_dict), - bases=[_load_annotation(_) for _ in obj_dict["bases"]], + bases=obj_dict["bases"], ) for class_member in obj_dict.get("members", []): class_.set_member(class_member.name, class_member) + _attach_parent_to_exprs(class_member, class_) class_.labels |= set(obj_dict.get("labels", ())) + _attach_parent_to_exprs(class_, class_) return class_ @@ -165,7 +199,7 @@ def _load_function(obj_dict: dict[str, Any]) -> Function: function = Function( name=obj_dict["name"], parameters=Parameters(*obj_dict["parameters"]), - returns=_load_annotation(obj_dict["returns"]), + returns=obj_dict["returns"], decorators=_load_decorators(obj_dict), lineno=obj_dict["lineno"], endlineno=obj_dict.get("endlineno", None), @@ -182,7 +216,7 @@ def _load_attribute(obj_dict: dict[str, Any]) -> Attribute: endlineno=obj_dict.get("endlineno", None), docstring=_load_docstring(obj_dict), value=obj_dict.get("value", None), - annotation=_load_annotation(obj_dict.get("annotation", None)), + annotation=obj_dict.get("annotation", None), ) attribute.labels |= set(obj_dict.get("labels", ())) return attribute @@ -206,7 +240,7 @@ def _load_alias(obj_dict: dict[str, Any]) -> Alias: } -def json_decoder(obj_dict: dict[str, Any]) -> dict[str, Any] | Object | Alias | Parameter: +def json_decoder(obj_dict: dict[str, Any]) -> dict[str, Any] | Object | Alias | Parameter | str | expressions.Expr: """Decode dictionaries as data classes. The [`json.loads`][] method walks the tree from bottom to top. @@ -222,6 +256,8 @@ def json_decoder(obj_dict: dict[str, Any]) -> dict[str, Any] | Object | Alias | Returns: An instance of a data class. """ + if "cls" in obj_dict: + return _load_expression(obj_dict) if "kind" in obj_dict: try: kind = Kind(obj_dict["kind"]) diff --git a/src/griffe/enumerations.py b/src/griffe/enumerations.py new file mode 100644 index 00000000..3a2e6b68 --- /dev/null +++ b/src/griffe/enumerations.py @@ -0,0 +1,139 @@ +"""This module contains all the enumerations of the package.""" + +from __future__ import annotations + +import enum + + +class DocstringSectionKind(enum.Enum): + """The possible section kinds.""" + + text = "text" + parameters = "parameters" + other_parameters = "other parameters" + raises = "raises" + warns = "warns" + returns = "returns" + yields = "yields" + receives = "receives" + examples = "examples" + attributes = "attributes" + deprecated = "deprecated" + admonition = "admonition" + + +class ParameterKind(enum.Enum): + """Enumeration of the different parameter kinds. + + Attributes: + positional_only: Positional-only parameter. + positional_or_keyword: Positional or keyword parameter. + var_positional: Variadic positional parameter. + keyword_only: Keyword-only parameter. + var_keyword: Variadic keyword parameter. + """ + + positional_only: str = "positional-only" + positional_or_keyword: str = "positional or keyword" + var_positional: str = "variadic positional" + keyword_only: str = "keyword-only" + var_keyword: str = "variadic keyword" + + +class Kind(enum.Enum): + """Enumeration of the different objects kinds. + + Attributes: + MODULE: The module kind. + CLASS: The class kind. + FUNCTION: The function kind. + ATTRIBUTE: The attribute kind. + """ + + MODULE: str = "module" + CLASS: str = "class" + FUNCTION: str = "function" + ATTRIBUTE: str = "attribute" + ALIAS: str = "alias" + + +class ExplanationStyle(enum.Enum): + """An enumeration of the possible styles for explanations.""" + + ONE_LINE: str = "oneline" + VERBOSE: str = "verbose" + + +class BreakageKind(enum.Enum): + """An enumeration of the possible breakages.""" + + PARAMETER_MOVED: str = "Positional parameter was moved" + PARAMETER_REMOVED: str = "Parameter was removed" + PARAMETER_CHANGED_KIND: str = "Parameter kind was changed" + PARAMETER_CHANGED_DEFAULT: str = "Parameter default was changed" + PARAMETER_CHANGED_REQUIRED: str = "Parameter is now required" + PARAMETER_ADDED_REQUIRED: str = "Parameter was added as required" + RETURN_CHANGED_TYPE: str = "Return types are incompatible" + OBJECT_REMOVED: str = "Public object was removed" + OBJECT_CHANGED_KIND: str = "Public object points to a different kind of object" + ATTRIBUTE_CHANGED_TYPE: str = "Attribute types are incompatible" + ATTRIBUTE_CHANGED_VALUE: str = "Attribute value was changed" + CLASS_REMOVED_BASE: str = "Base class was removed" + + +class Parser(enum.Enum): + """Enumeration for the different docstring parsers.""" + + google = "google" + sphinx = "sphinx" + numpy = "numpy" + + +class ObjectKind(enum.Enum): + """Enumeration for the different kinds of objects.""" + + MODULE: str = "module" + """Modules.""" + CLASS: str = "class" + """Classes.""" + STATICMETHOD: str = "staticmethod" + """Static methods.""" + CLASSMETHOD: str = "classmethod" + """Class methods.""" + METHOD_DESCRIPTOR: str = "method_descriptor" + """Method descriptors.""" + METHOD: str = "method" + """Methods.""" + BUILTIN_METHOD: str = "builtin_method" + """Built-in ethods.""" + COROUTINE: str = "coroutine" + """Coroutines""" + FUNCTION: str = "function" + """Functions.""" + BUILTIN_FUNCTION: str = "builtin_function" + """Built-in functions.""" + CACHED_PROPERTY: str = "cached_property" + """Cached properties.""" + PROPERTY: str = "property" + """Properties.""" + ATTRIBUTE: str = "attribute" + """Attributes.""" + + def __str__(self) -> str: + return self.value + + +class When(enum.Enum): + """This enumeration contains the different times at which an extension is used. + + Attributes: + before_all: For each node, before the visit/inspection. + before_children: For each node, after the visit has started, and before the children visit/inspection. + after_children: For each node, after the children have been visited/inspected, and before finishing the visit/inspection. + after_all: For each node, after the visit/inspection. + """ + + before_all: int = 1 + before_children: int = 2 + after_children: int = 3 + after_all: int = 4 diff --git a/src/griffe/expressions.py b/src/griffe/expressions.py index 4f9d419e..78bdf138 100644 --- a/src/griffe/expressions.py +++ b/src/griffe/expressions.py @@ -2,63 +2,462 @@ from __future__ import annotations -from functools import cached_property -from typing import Any, Callable +import ast +import sys +from dataclasses import dataclass +from dataclasses import fields as getfields +from functools import partial +from itertools import zip_longest +from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence +from griffe.enumerations import ParameterKind from griffe.exceptions import NameResolutionError +from griffe.logger import LogLevel, get_logger +if TYPE_CHECKING: + from pathlib import Path -class Name: - """This class represents a Python object identified by a name in a given scope. + from griffe.dataclasses import Class, Module - Attributes: - source: The name as written in the source code. - """ - def __init__(self, source: str, full: str | Callable, *, first_attr_name: bool = True) -> None: - """Initialize the name. +logger = get_logger(__name__) + + +def _yield(element: str | Expr | tuple[str | Expr, ...]) -> Iterator[str | ExprName | ExprAttribute]: + if isinstance(element, (str, ExprAttribute)): + yield element + elif isinstance(element, tuple): + for elem in element: + yield from _yield(elem) + else: + yield from element + + +def _join( + elements: Iterable[str | Expr | tuple[str | Expr, ...]], + joint: str | Expr, +) -> Iterator[str | ExprName | ExprAttribute]: + it = iter(elements) + try: + yield from _yield(next(it)) + except StopIteration: + return + for element in it: + yield from _yield(joint) + yield from _yield(element) + + +def _field_as_dict( + element: str | bool | Expr | list[str | Expr] | None, + **kwargs: Any, +) -> str | bool | None | list | dict: + if isinstance(element, Expr): + return _expr_as_dict(element, **kwargs) + if isinstance(element, list): + return [_field_as_dict(elem, **kwargs) for elem in element] + return element + + +def _expr_as_dict(expression: Expr, **kwargs: Any) -> dict[str, Any]: + fields = { + field.name: _field_as_dict(getattr(expression, field.name), **kwargs) + for field in sorted(getfields(expression), key=lambda f: f.name) + if field.name != "parent" + } + fields["cls"] = expression.__class__.__name__ + return fields + + +# TODO: merge in decorators once Python 3.9 is dropped +dataclass_opts: dict[str, bool] = {} +if sys.version_info >= (3, 10): + dataclass_opts["slots"] = True + + +@dataclass +class Expr: + """Base class for expressions.""" + + def __str__(self) -> str: + return "".join( + elem if isinstance(elem, str) else elem.name if isinstance(elem, ExprName) else str(elem) for elem in self + ) + + def __iter__(self) -> Iterator[str | ExprName | ExprAttribute]: + yield from () + + def as_dict(self, **kwargs: Any) -> dict[str, Any]: + """Return the expression as a dictionary. Parameters: - source: The name as written in the source code. - full: The full, resolved name in the given scope, or a callable to resolve it later. - first_attr_name: Whether this name is the first in a chain of names representing - an attribute (dot separated strings). + **kwargs: Configuration options (none available yet). + + + Returns: + A dictionary. """ - self.source: str = source - if isinstance(full, str): - self._full: str = full - self._resolver: Callable = lambda: None - else: - self._full = "" - self._resolver = full - self.first_attr_name: bool = first_attr_name - - def __eq__(self, other: Any) -> bool: - if isinstance(other, str): - return self.full == other or self.brief == other - if isinstance(other, Name): - return self.full == other.full - if isinstance(other, Expression): - return self.full == other.source - raise NotImplementedError(f"uncomparable types: {type(self)} and {type(other)}") - - def __repr__(self) -> str: - return f"Name(source={self.source!r}, full={self.full!r})" + return _expr_as_dict(self, **kwargs) - def __str__(self) -> str: - return self.source + @property + def kind(self) -> str: + """The expression kind.""" + return self.__class__.__name__.lower()[4:] @property - def brief(self) -> str: - """Return the brief source name. + def path(self) -> str: + """Path of the expressed name/attribute.""" + return str(self) - Returns: - The last part of the source name. + @property + def canonical_path(self) -> str: + """Path of the expressed name/attribute.""" + return str(self) + + @property + def canonical_name(self) -> str: + """Name of the expressed name/attribute.""" + return self.canonical_path.rsplit(".", 1)[-1] + + @property + def is_classvar(self) -> bool: + """Whether this attribute is annotated with `ClassVar`.""" + return isinstance(self, ExprSubscript) and self.canonical_name == "ClassVar" + + @property + def is_tuple(self) -> bool: + """Whether this expression is a tuple.""" + return isinstance(self, ExprSubscript) and self.canonical_name.lower() == "tuple" + + @property + def is_iterator(self) -> bool: + """Whether this expression is an iterator.""" + return isinstance(self, ExprSubscript) and self.canonical_name == "Iterator" + + @property + def is_generator(self) -> bool: + """Whether this expression is a generator.""" + return isinstance(self, ExprSubscript) and self.canonical_name == "Generator" + + +@dataclass(eq=True, **dataclass_opts) +class ExprAttribute(Expr): + """Attributes like `a.b`.""" + + values: list[str | Expr] + + def __iter__(self) -> Iterator[str | ExprName | ExprAttribute]: + yield from _join(self.values, ".") + + def append(self, value: ExprName) -> None: + """Append a name to this attribute. + + Parameters: + value: The expression name to append. """ - return self.source.rsplit(".", 1)[-1] + if value.parent is None: + value.parent = self.last + self.values.append(value) + + @property + def last(self) -> ExprName: + """The last part of this attribute (on the right).""" + # All values except the first one can *only* be names: + # we can't do `a.(b or c)` or `a."string"`. + return self.values[-1] # type: ignore[return-value] + + @property + def first(self) -> str | Expr: + """The first part of this attribute (on the left).""" + return self.values[0] @property - def full(self) -> str: + def path(self) -> str: + """The path of this attribute.""" + return self.last.path + + @property + def canonical_path(self) -> str: + """The canonical path of this attribute.""" + return self.last.canonical_path + + +@dataclass(eq=True, **dataclass_opts) +class ExprBinOp(Expr): + """Binary operations like `a + b`.""" + + left: str | Expr + operator: str + right: str | Expr + + def __iter__(self) -> Iterator[str | ExprName | ExprAttribute]: + yield from _yield(self.left) + yield f" {self.operator} " + yield from _yield(self.right) + + +@dataclass(eq=True, **dataclass_opts) +class ExprBoolOp(Expr): + """Boolean operations like `a or b`.""" + + operator: str + values: Sequence[str | Expr] + + def __iter__(self) -> Iterator[str | ExprName | ExprAttribute]: + yield from _join(self.values, f" {self.operator} ") + + +@dataclass(eq=True, **dataclass_opts) +class ExprCall(Expr): + """Calls like `f()`.""" + + function: Expr + arguments: Sequence[str | Expr] + + def __iter__(self) -> Iterator[str | ExprName | ExprAttribute]: + yield from self.function + yield "(" + yield from _join(self.arguments, ", ") + yield ")" + + +@dataclass(eq=True, **dataclass_opts) +class ExprCompare(Expr): + """Comparisons like `a > b`.""" + + left: str | Expr + operators: Sequence[str] + comparators: Sequence[str | Expr] + + def __iter__(self) -> Iterator[str | ExprName | ExprAttribute]: + yield from _yield(self.left) + yield " " + yield from _join(zip_longest(self.operators, [], self.comparators, fillvalue=" "), " ") + + +@dataclass(eq=True, **dataclass_opts) +class ExprComprehension(Expr): + """Comprehensions like `a for b in c if d`.""" + + target: str | Expr + iterable: str | Expr + conditions: Sequence[str | Expr] + is_async: bool = False + + def __iter__(self) -> Iterator[str | ExprName | ExprAttribute]: + if self.is_async: + yield "async " + yield "for " + yield from _yield(self.target) + yield " in " + yield from _yield(self.iterable) + if self.conditions: + yield " if " + yield from _join(self.conditions, " if ") + + +@dataclass(eq=True, **dataclass_opts) +class ExprConstant(Expr): + """Constants like `"a"` or `1`.""" + + value: str + + def __iter__(self) -> Iterator[str | ExprName | ExprAttribute]: + yield self.value + + +@dataclass(eq=True, **dataclass_opts) +class ExprDict(Expr): + """Dictionaries like `{"a": 0}`.""" + + keys: Sequence[str | Expr | None] + values: Sequence[str | Expr] + + def __iter__(self) -> Iterator[str | ExprName | ExprAttribute]: + yield "{" + yield from _join( + (("None" if key is None else key, ": ", value) for key, value in zip(self.keys, self.values)), + ", ", + ) + yield "}" + + +@dataclass(eq=True, **dataclass_opts) +class ExprDictComp(Expr): + """Dict comprehensions like `{k: v for k, v in a}`.""" + + key: str | Expr + value: str | Expr + generators: Sequence[Expr] + + def __iter__(self) -> Iterator[str | ExprName | ExprAttribute]: + yield "{" + yield from _yield(self.key) + yield ": " + yield from _yield(self.value) + yield from _join(self.generators, " ") + yield "}" + + +@dataclass(eq=True, **dataclass_opts) +class ExprExtSlice(Expr): + """Extended slice like `a[x:y, z]`.""" + + dims: Sequence[str | Expr] + + def __iter__(self) -> Iterator[str | ExprName | ExprAttribute]: + yield from _join(self.dims, ", ") + + +@dataclass(eq=True, **dataclass_opts) +class ExprFormatted(Expr): + """Formatted string like `{1 + 1}`.""" + + value: str | Expr + + def __iter__(self) -> Iterator[str | ExprName | ExprAttribute]: + yield "{" + yield from _yield(self.value) + yield "}" + + +@dataclass(eq=True, **dataclass_opts) +class ExprGeneratorExp(Expr): + """Generator expressions like `a for b in c for d in e`.""" + + element: str | Expr + generators: Sequence[Expr] + + def __iter__(self) -> Iterator[str | ExprName | ExprAttribute]: + yield from _yield(self.element) + yield " " + yield from _join(self.generators, " ") + + +@dataclass(eq=True, **dataclass_opts) +class ExprIfExp(Expr): + """Conditions like `a if b else c`.""" + + body: str | Expr + test: str | Expr + orelse: str | Expr + + def __iter__(self) -> Iterator[str | ExprName | ExprAttribute]: + yield from _yield(self.body) + yield " if " + yield from _yield(self.test) + yield " else " + yield from _yield(self.orelse) + + +@dataclass(eq=True, **dataclass_opts) +class ExprJoinedStr(Expr): + """Joined strings like `f"a {b} c"`.""" + + values: Sequence[str | Expr] + + def __iter__(self) -> Iterator[str | ExprName | ExprAttribute]: + yield "f'" + yield from _join(self.values, "") + yield "'" + + +@dataclass(eq=True, **dataclass_opts) +class ExprKeyword(Expr): + """Keyword arguments like `a=b`.""" + + name: str + value: str | Expr + + def __iter__(self) -> Iterator[str | ExprName | ExprAttribute]: + yield self.name + yield "=" + yield from _yield(self.value) + + +@dataclass(eq=True, **dataclass_opts) +class ExprVarPositional(Expr): + """Variadic positional parameters like `*args`.""" + + value: Expr + + def __iter__(self) -> Iterator[str | ExprName | ExprAttribute]: + yield "*" + yield from self.value + + +@dataclass(eq=True, **dataclass_opts) +class ExprVarKeyword(Expr): + """Variadic keyword parameters like `**kwargs`.""" + + value: Expr + + def __iter__(self) -> Iterator[str | ExprName | ExprAttribute]: + yield "**" + yield from self.value + + +@dataclass(eq=True, **dataclass_opts) +class ExprLambda(Expr): + """Lambda expressions like `lambda a: a.b`.""" + + parameters: Sequence[ExprParameter] + body: str | Expr + + def __iter__(self) -> Iterator[str | ExprName | ExprAttribute]: + yield "lambda " + yield from _join(self.parameters, ", ") + yield ": " + yield from _yield(self.body) + + +@dataclass(eq=True, **dataclass_opts) +class ExprList(Expr): + """Lists like `[0, 1, 2]`.""" + + elements: Sequence[Expr] + + def __iter__(self) -> Iterator[str | ExprName | ExprAttribute]: + yield "[" + yield from _join(self.elements, ", ") + yield "]" + + +@dataclass(eq=True, **dataclass_opts) +class ExprListComp(Expr): + """List comprehensions like `[a for b in c]`.""" + + element: str | Expr + generators: Sequence[Expr] + + def __iter__(self) -> Iterator[str | ExprName | ExprAttribute]: + yield "[" + yield from _yield(self.element) + yield " " + yield from _join(self.generators, " ") + yield "]" + + +@dataclass(eq=False, **dataclass_opts) +class ExprName(Expr): + """This class represents a Python object identified by a name in a given scope. + + Attributes: + source: The name as written in the source code. + """ + + name: str + parent: str | ExprName | Module | Class | None = None + + def __eq__(self, other: object) -> bool: + if isinstance(other, ExprName): + return self.name == other.name + return NotImplemented + + def __iter__(self) -> Iterator[ExprName]: + yield self + + @property + def path(self) -> str: """Return the full, resolved name. If it was given when creating the name, return that. @@ -68,210 +467,606 @@ def full(self) -> str: Returns: The resolved name or the source. """ - if not self._full: - try: - self._full = self._resolver() or self.source - except NameResolutionError: - # probably a built-in - self._full = self.source - return self._full + if isinstance(self.parent, ExprName): + return f"{self.parent.path}.{self.name}" + return self.name @property - def canonical(self) -> str: + def canonical_path(self) -> str: """Return the canonical name (resolved one, not alias name). Returns: The canonical name. """ - return self.full.rsplit(".", 1)[-1] + if self.parent is None: + return self.name + if isinstance(self.parent, ExprName): + return f"{self.parent.canonical_path}.{self.name}" + if isinstance(self.parent, str): + return f"{self.parent}.{self.name}" + try: + return self.parent.resolve(self.name) + except NameResolutionError: + return self.name - def as_dict(self, **kwargs: Any) -> dict[str, Any]: # noqa: ARG002 - """Return this name's data as a dictionary. - Parameters: - **kwargs: Additional serialization options. +@dataclass(eq=True, **dataclass_opts) +class ExprNamedExpr(Expr): + """Named/assignment expressions like `a := b`.""" - Returns: - A dictionary. - """ - return {"source": self.source, "full": self.full} + target: Expr + value: str | Expr + def __iter__(self) -> Iterator[str | ExprName | ExprAttribute]: + yield "(" + yield from self.target + yield " := " + yield from _yield(self.value) + yield ")" -class Expression(list): - """This class represents a Python expression. - For example, it can represent complex annotations such as: +@dataclass(eq=True, **dataclass_opts) +class ExprParameter(Expr): + """Parameters in function signatures like `a: int = 0`.""" - - `Optional[Dict[str, Tuple[int, bool]]]` - - `str | Callable | list[int]` + kind: str + name: str | None = None + annotation: Expr | None = None + default: Expr | None = None - Expressions are simple lists containing strings, names or expressions. - Each name in the expression can be resolved to its full name within its scope. - """ - def __init__(self, *values: str | Expression | Name) -> None: - """Initialize the expression. +@dataclass(eq=True, **dataclass_opts) +class ExprSet(Expr): + """Sets like `{0, 1, 2}`.""" - Parameters: - *values: The initial values of the expression. - """ - super().__init__() - self.extend(values) + elements: Sequence[str | Expr] - def __str__(self): - return "".join(str(element) for element in self) + def __iter__(self) -> Iterator[str | ExprName | ExprAttribute]: + yield "{" + yield from _join(self.elements, ", ") + yield "}" - @property - def source(self) -> str: - """Return the expression as written in the source. - This property is only useful to the AST utils. +@dataclass(eq=True, **dataclass_opts) +class ExprSetComp(Expr): + """Set comprehensions like `{a for b in c}`.""" - Returns: - The expression as a string. - """ - return str(self) + element: str | Expr + generators: Sequence[Expr] + + def __iter__(self) -> Iterator[str | ExprName | ExprAttribute]: + yield "{" + yield from _yield(self.element) + yield " " + yield from _join(self.generators, " ") + yield "}" + + +@dataclass(eq=True, **dataclass_opts) +class ExprSlice(Expr): + """Slices like `[a:b:c]`.""" + + lower: str | Expr | None = None + upper: str | Expr | None = None + step: str | Expr | None = None + + def __iter__(self) -> Iterator[str | ExprName | ExprAttribute]: + if self.lower is not None: + yield from _yield(self.lower) + yield ":" + if self.upper is not None: + yield from _yield(self.upper) + if self.step is not None: + yield ":" + yield from _yield(self.step) + + +@dataclass(eq=True, **dataclass_opts) +class ExprSubscript(Expr): + """Subscripts like `a[b]`.""" + + left: Expr + slice: Expr # noqa: A003 + + def __iter__(self) -> Iterator[str | ExprName | ExprAttribute]: + yield from self.left + yield "[" + yield from self.slice + yield "]" @property - def full(self) -> str: - """Return the full expression as a string with canonical names (imported ones, not aliases). + def path(self) -> str: + """The path of this subscript's left part.""" + return self.left.path - This property is only useful to the AST utils. + @property + def canonical_path(self) -> str: + """The canonical path of this subscript's left part.""" + return self.left.canonical_path - Returns: - The expression as a string. - """ - parts = [] - for element in self: - if isinstance(element, str): - parts.append(element) - elif isinstance(element, Name): - parts.append(element.full if element.first_attr_name else element.canonical) + +@dataclass(eq=True, **dataclass_opts) +class ExprTuple(Expr): + """Tuples like `(0, 1, 2)`.""" + + elements: Sequence[str | Expr] + implicit: bool = False + + def __iter__(self) -> Iterator[str | ExprName | ExprAttribute]: + if not self.implicit: + yield "(" + yield from _join(self.elements, ", ") + if not self.implicit: + yield ")" + + +@dataclass(eq=True, **dataclass_opts) +class ExprUnaryOp(Expr): + """Unary operations like `-1`.""" + + operator: str + value: str | Expr + + def __iter__(self) -> Iterator[str | ExprName | ExprAttribute]: + yield self.operator + yield from _yield(self.value) + + +@dataclass(eq=True, **dataclass_opts) +class ExprYield(Expr): + """Yield statements like `yield a`.""" + + value: str | Expr | None = None + + def __iter__(self) -> Iterator[str | ExprName | ExprAttribute]: + yield "yield" + if self.value is not None: + yield " " + yield from _yield(self.value) + + +_unary_op_map = { + ast.Invert: "~", + ast.Not: "not ", + ast.UAdd: "+", + ast.USub: "-", +} + +_binary_op_map = { + ast.Add: "+", + ast.BitAnd: "&", + ast.BitOr: "|", + ast.BitXor: "^", + ast.Div: "/", + ast.FloorDiv: "//", + ast.LShift: "<<", + ast.MatMult: "@", + ast.Mod: "%", + ast.Mult: "*", + ast.Pow: "**", + ast.RShift: ">>", + ast.Sub: "-", +} + +_bool_op_map = { + ast.And: "and", + ast.Or: "or", +} + +_compare_op_map = { + ast.Eq: "==", + ast.NotEq: "!=", + ast.Lt: "<", + ast.LtE: "<=", + ast.Gt: ">", + ast.GtE: ">=", + ast.Is: "is", + ast.IsNot: "is not", + ast.In: "in", + ast.NotIn: "not in", +} + + +def _build_attribute(node: ast.Attribute, parent: Module | Class, **kwargs: Any) -> Expr: + left = _build(node.value, parent, **kwargs) + if isinstance(left, ExprAttribute): + left.append(ExprName(node.attr)) + return left + if isinstance(left, ExprName): + return ExprAttribute([left, ExprName(node.attr, left)]) + if isinstance(left, str): + return ExprAttribute([left, ExprName(node.attr, "str")]) + return ExprAttribute([left, ExprName(node.attr)]) + + +def _build_binop(node: ast.BinOp, parent: Module | Class, **kwargs: Any) -> Expr: + return ExprBinOp( + _build(node.left, parent, **kwargs), + _binary_op_map[type(node.op)], + _build(node.right, parent, **kwargs), + ) + + +def _build_boolop(node: ast.BoolOp, parent: Module | Class, **kwargs: Any) -> Expr: + return ExprBoolOp( + _bool_op_map[type(node.op)], + [_build(value, parent, **kwargs) for value in node.values], + ) + + +def _build_call(node: ast.Call, parent: Module | Class, **kwargs: Any) -> Expr: + positional_args = [_build(arg, parent, **kwargs) for arg in node.args] + keyword_args = [_build(kwarg, parent, **kwargs) for kwarg in node.keywords] + return ExprCall(_build(node.func, parent, **kwargs), [*positional_args, *keyword_args]) + + +def _build_compare(node: ast.Compare, parent: Module | Class, **kwargs: Any) -> Expr: + return ExprCompare( + _build(node.left, parent, **kwargs), + [_compare_op_map[type(op)] for op in node.ops], + [_build(comp, parent, **kwargs) for comp in node.comparators], + ) + + +def _build_comprehension(node: ast.comprehension, parent: Module | Class, **kwargs: Any) -> Expr: + return ExprComprehension( + _build(node.target, parent, **kwargs), + _build(node.iter, parent, **kwargs), + [_build(condition, parent, **kwargs) for condition in node.ifs], + is_async=bool(node.is_async), + ) + + +def _build_constant( + node: ast.Constant, + parent: Module | Class, + *, + in_formatted_str: bool = False, + in_joined_str: bool = False, + parse_strings: bool = False, + literal_strings: bool = False, + **kwargs: Any, +) -> str | Expr: + if isinstance(node.value, str): + if in_joined_str and not in_formatted_str: + # We're in a f-string, not in a formatted value, don't keep quotes. + return node.value + if parse_strings and not literal_strings: + # We're in a place where a string could be a type annotation + # (and not in a Literal[...] type annotation). + # We parse the string and build from the resulting nodes again. + # If we fail to parse it (syntax errors), we consider it's a literal string and log a message. + try: + parsed = compile( + node.value, + mode="eval", + filename="", + flags=ast.PyCF_ONLY_AST, + optimize=1, + ) + except SyntaxError: + logger.debug( + f"Tried and failed to parse {node.value!r} as Python code, " + "falling back to using it as a string literal " + "(postponed annotations might help: https://peps.python.org/pep-0563/)", + ) else: - parts.append(element.full) - return "".join(parts) + return _build(parsed.body, parent, **kwargs) # type: ignore[attr-defined] + return {type(...): lambda _: "..."}.get(type(node.value), repr)(node.value) - @property - def kind(self) -> str: - """Return the main type object as a string. - Returns: - The main type of this expression. - """ - return str(self.non_optional).split("[", 1)[0].rsplit(".", 1)[-1].lower() +def _build_dict(node: ast.Dict, parent: Module | Class, **kwargs: Any) -> Expr: + return ExprDict( + [None if key is None else _build(key, parent, **kwargs) for key in node.keys], + [_build(value, parent, **kwargs) for value in node.values], + ) - @property - def without_subscript(self) -> Expression: - """The expression without the subscript part (if any). - For example, `Generic[T]` becomes `Generic`. - """ - parts = [] - for element in self: - if isinstance(element, str) and element == "[": - break - parts.append(element) - return Expression(*parts) +def _build_dictcomp(node: ast.DictComp, parent: Module | Class, **kwargs: Any) -> Expr: + return ExprDictComp( + _build(node.key, parent, **kwargs), + _build(node.value, parent, **kwargs), + [_build(gen, parent, **kwargs) for gen in node.generators], + ) - @property - def is_tuple(self) -> bool: - """Tell whether this expression represents a tuple. - Returns: - True or False. - """ - return self.kind == "tuple" +def _build_formatted(node: ast.FormattedValue, parent: Module | Class, **kwargs: Any) -> Expr: + return ExprFormatted(_build(node.value, parent, in_formatted_str=True, **kwargs)) - @property - def is_iterator(self) -> bool: - """Tell whether this expression represents an iterator. - Returns: - True or False. - """ - return self.kind == "iterator" +def _build_generatorexp(node: ast.GeneratorExp, parent: Module | Class, **kwargs: Any) -> Expr: + return ExprGeneratorExp( + _build(node.elt, parent, **kwargs), + [_build(gen, parent, **kwargs) for gen in node.generators], + ) - @property - def is_generator(self) -> bool: - """Tell whether this expression represents a generator. - Returns: - True or False. - """ - return self.kind == "generator" +def _build_ifexp(node: ast.IfExp, parent: Module | Class, **kwargs: Any) -> Expr: + return ExprIfExp( + _build(node.body, parent, **kwargs), + _build(node.test, parent, **kwargs), + _build(node.orelse, parent, **kwargs), + ) - @property - def is_classvar(self) -> bool: - """Tell whether this expression represents a ClassVar. - Returns: - True or False. - """ - return isinstance(self[0], Name) and self[0].full == "typing.ClassVar" +def _build_joinedstr( + node: ast.JoinedStr, + parent: Module | Class, + *, + in_joined_str: bool = False, # noqa: ARG001 + **kwargs: Any, +) -> Expr: + return ExprJoinedStr([_build(value, parent, in_joined_str=True, **kwargs) for value in node.values]) - @cached_property - def non_optional(self) -> Expression: - """Return the same expression as non-optional. - This will return a new expression without - the `Optional[]` or `| None` parts. +def _build_keyword(node: ast.keyword, parent: Module | Class, **kwargs: Any) -> Expr: + if node.arg is None: + return ExprVarKeyword(_build(node.value, parent, **kwargs)) + return ExprKeyword(node.arg, _build(node.value, parent, **kwargs)) - Returns: - A non-optional expression. - """ - if self[-3:] == ["|", " ", "None"]: - if isinstance(self[0], Expression): - return self[0] - return Expression(self[0]) - if self[:3] == ["None", " ", "|"]: - if isinstance(self[3], Expression): - return self[3] - return Expression(self[3]) - if isinstance(self[0], Name) and self[0].full == "typing.Optional": - if isinstance(self[2], Expression): - return self[2] - return Expression(self[2]) - return self - - def tuple_item(self, nth: int) -> str | Name: - """Return the n-th item of this tuple expression. - Parameters: - nth: The item number. +def _build_lambda(node: ast.Lambda, parent: Module | Class, **kwargs: Any) -> Expr: + # TODO: better parameter handling + return ExprLambda( + [ExprParameter(ParameterKind.positional_or_keyword.value, arg.arg) for arg in node.args.args], + _build(node.body, parent, **kwargs), + ) - Returns: - A string or name. - """ - # 0 1 2 3 - # N , N , N - # 0 1 2 3 4 - return self.non_optional[2][2 * nth] - def tuple_items(self) -> list[Name | Expression]: - """Return a tuple items as a list. +def _build_list(node: ast.List, parent: Module | Class, **kwargs: Any) -> Expr: + return ExprList([_build(el, parent, **kwargs) for el in node.elts]) - Returns: - The tuple items. - """ - return self.non_optional[2][::2] - def iterator_item(self) -> Name | Expression: - """Return the item of an iterator. +def _build_listcomp(node: ast.ListComp, parent: Module | Class, **kwargs: Any) -> Expr: + return ExprListComp(_build(node.elt, parent, **kwargs), [_build(gen, parent, **kwargs) for gen in node.generators]) - Returns: - The iterator item. - """ - return self.non_optional[2] - def generator_items(self) -> tuple[Name | Expression, Name | Expression, Name | Expression]: - """Return the items of a generator. +def _build_name(node: ast.Name, parent: Module | Class, **kwargs: Any) -> Expr: # noqa: ARG001 + return ExprName(node.id, parent) - Returns: - The yield type. - The send/receive type. - The return type. - """ - return self.non_optional[2][0], self[2][2], self[2][4] + +def _build_named_expr(node: ast.NamedExpr, parent: Module | Class, **kwargs: Any) -> Expr: + return ExprNamedExpr(_build(node.target, parent, **kwargs), _build(node.value, parent, **kwargs)) + + +def _build_set(node: ast.Set, parent: Module | Class, **kwargs: Any) -> Expr: + return ExprSet([_build(el, parent, **kwargs) for el in node.elts]) + + +def _build_setcomp(node: ast.SetComp, parent: Module | Class, **kwargs: Any) -> Expr: + return ExprSetComp(_build(node.elt, parent, **kwargs), [_build(gen, parent, **kwargs) for gen in node.generators]) + + +def _build_slice(node: ast.Slice, parent: Module | Class, **kwargs: Any) -> Expr: + return ExprSlice( + None if node.lower is None else _build(node.lower, parent, **kwargs), + None if node.upper is None else _build(node.upper, parent, **kwargs), + None if node.step is None else _build(node.step, parent, **kwargs), + ) + + +def _build_starred(node: ast.Starred, parent: Module | Class, **kwargs: Any) -> Expr: + return ExprVarPositional(_build(node.value, parent, **kwargs)) + + +def _build_subscript( + node: ast.Subscript, + parent: Module | Class, + *, + parse_strings: bool = False, + literal_strings: bool = False, + in_subscript: bool = False, # noqa: ARG001 + **kwargs: Any, +) -> Expr: + left = _build(node.value, parent, **kwargs) + if parse_strings: + if isinstance(left, (ExprAttribute, ExprName)) and left.canonical_path in { + "typing.Literal", + "typing_extensions.Literal", + }: + literal_strings = True + slice = _build( + node.slice, + parent, + parse_strings=True, + literal_strings=literal_strings, + in_subscript=True, + **kwargs, + ) + else: + slice = _build(node.slice, parent, in_subscript=True, **kwargs) + return ExprSubscript(left, slice) + + +def _build_tuple( + node: ast.Tuple, + parent: Module | Class, + *, + in_subscript: bool = False, + **kwargs: Any, +) -> Expr: + return ExprTuple([_build(el, parent, **kwargs) for el in node.elts], implicit=in_subscript) + + +def _build_unaryop(node: ast.UnaryOp, parent: Module | Class, **kwargs: Any) -> Expr: + return ExprUnaryOp(_unary_op_map[type(node.op)], _build(node.operand, parent, **kwargs)) + + +def _build_yield(node: ast.Yield, parent: Module | Class, **kwargs: Any) -> Expr: + return ExprYield(None if node.value is None else _build(node.value, parent, **kwargs)) + + +_node_map: dict[type, Callable[[Any, Module | Class], Expr]] = { + ast.Attribute: _build_attribute, + ast.BinOp: _build_binop, + ast.BoolOp: _build_boolop, + ast.Call: _build_call, + ast.Compare: _build_compare, + ast.comprehension: _build_comprehension, + ast.Constant: _build_constant, # type: ignore[dict-item] + ast.Dict: _build_dict, + ast.DictComp: _build_dictcomp, + ast.FormattedValue: _build_formatted, + ast.GeneratorExp: _build_generatorexp, + ast.IfExp: _build_ifexp, + ast.JoinedStr: _build_joinedstr, + ast.keyword: _build_keyword, + ast.Lambda: _build_lambda, + ast.List: _build_list, + ast.ListComp: _build_listcomp, + ast.Name: _build_name, + ast.NamedExpr: _build_named_expr, + ast.Set: _build_set, + ast.SetComp: _build_setcomp, + ast.Slice: _build_slice, + ast.Starred: _build_starred, + ast.Subscript: _build_subscript, + ast.Tuple: _build_tuple, + ast.UnaryOp: _build_unaryop, + ast.Yield: _build_yield, +} + +# TODO: remove once Python 3.8 support is dropped +if sys.version_info < (3, 9): + + def _build_extslice(node: ast.ExtSlice, parent: Module | Class, **kwargs: Any) -> Expr: + return ExprExtSlice([_build(dim, parent, **kwargs) for dim in node.dims]) + + def _build_index(node: ast.Index, parent: Module | Class, **kwargs: Any) -> Expr: + return _build(node.value, parent, **kwargs) + + _node_map[ast.ExtSlice] = _build_extslice + _node_map[ast.Index] = _build_index + + +def _build(node: ast.AST, parent: Module | Class, **kwargs: Any) -> Expr: + return _node_map[type(node)](node, parent, **kwargs) + + +def get_expression( + node: ast.AST | None, + parent: Module | Class, + *, + parse_strings: bool | None = None, +) -> Expr | None: + """Build an expression from an AST. + + Parameters: + node: The annotation node. + parent: The parent used to resolve the name. + parse_strings: Whether to try and parse strings as type annotations. + + Returns: + A string or resovable name or expression. + """ + if node is None: + return None + if parse_strings is None: + try: + module = parent.module + except ValueError: + parse_strings = False + else: + parse_strings = not module.imports_future_annotations + return _build(node, parent, parse_strings=parse_strings) + + +def safe_get_expression( + node: ast.AST | None, + parent: Module | Class, + *, + parse_strings: bool | None = None, + log_level: LogLevel | None = LogLevel.error, + msg_format: str = "{path}:{lineno}: Failed to get expression from {node_class}: {error}", +) -> Expr | None: + """Safely (no exception) build a resolvable annotation. + + Parameters: + node: The annotation node. + parent: The parent used to resolve the name. + parse_strings: Whether to try and parse strings as type annotations. + log_level: Log level to use to log a message. None to disable logging. + msg_format: A format string for the log message. Available placeholders: + path, lineno, node, error. + + Returns: + A string or resovable name or expression. + """ + try: + return get_expression(node, parent, parse_strings=parse_strings) + except Exception as error: # noqa: BLE001 + if log_level is None: + return None + node_class = node.__class__.__name__ + try: + path: Path | str = parent.relative_filepath + except ValueError: + path = "" + lineno = node.lineno # type: ignore[union-attr] + message = msg_format.format(path=path, lineno=lineno, node_class=node_class, error=error) + getattr(logger, log_level.value)(message) + return None + + +_msg_format = "{path}:{lineno}: Failed to get %s expression from {node_class}: {error}" +get_annotation = partial(get_expression, parse_strings=None) +safe_get_annotation = partial( + safe_get_expression, + parse_strings=None, + msg_format=_msg_format % "annotation", +) +get_base_class = partial(get_expression, parse_strings=False) +safe_get_base_class = partial( + safe_get_expression, + parse_strings=False, + msg_format=_msg_format % "base class", +) +get_condition = partial(get_expression, parse_strings=False) +safe_get_condition = partial( + safe_get_expression, + parse_strings=False, + msg_format=_msg_format % "condition", +) -__all__ = ["Expression", "Name"] +__all__ = [ + "Expr", + "ExprAttribute", + "ExprBinOp", + "ExprBoolOp", + "ExprCall", + "ExprCompare", + "ExprComprehension", + "ExprConstant", + "ExprDict", + "ExprDictComp", + "ExprExtSlice", + "ExprFormatted", + "ExprGeneratorExp", + "ExprIfExp", + "ExprJoinedStr", + "ExprKeyword", + "ExprVarPositional", + "ExprVarKeyword", + "ExprLambda", + "ExprList", + "ExprListComp", + "ExprName", + "ExprNamedExpr", + "ExprParameter", + "ExprSet", + "ExprSetComp", + "ExprSlice", + "ExprSubscript", + "ExprTuple", + "ExprUnaryOp", + "ExprYield", + "get_annotation", + "get_base_class", + "get_condition", + "get_expression", + "safe_get_annotation", + "safe_get_base_class", + "safe_get_condition", + "safe_get_expression", +] diff --git a/src/griffe/extensions/base.py b/src/griffe/extensions/base.py index 20a91581..460a03a9 100644 --- a/src/griffe/extensions/base.py +++ b/src/griffe/extensions/base.py @@ -2,7 +2,6 @@ from __future__ import annotations -import enum import os import sys import warnings @@ -12,6 +11,7 @@ from typing import TYPE_CHECKING, Any, Sequence, Union from griffe.agents.nodes import ast_children, ast_kind +from griffe.enumerations import When from griffe.exceptions import ExtensionNotLoadedError from griffe.importer import dynamic_import @@ -25,22 +25,6 @@ from griffe.dataclasses import Attribute, Class, Function, Module, Object -class When(enum.Enum): - """This enumeration contains the different times at which an extension is used. - - Attributes: - before_all: For each node, before the visit/inspection. - before_children: For each node, after the visit has started, and before the children visit/inspection. - after_children: For each node, after the children have been visited/inspected, and before finishing the visit/inspection. - after_all: For each node, after the visit/inspection. - """ - - before_all: int = 1 - before_children: int = 2 - after_children: int = 3 - after_all: int = 4 - - class VisitorExtension: """Deprecated in favor of `Extension`. The node visitor extension base class, to inherit from.""" diff --git a/src/griffe/loader.py b/src/griffe/loader.py index dd0bbf50..8c9dcdf3 100644 --- a/src/griffe/loader.py +++ b/src/griffe/loader.py @@ -22,7 +22,7 @@ from griffe.collections import LinesCollection, ModulesCollection from griffe.dataclasses import Alias, Kind, Module, Object from griffe.exceptions import AliasResolutionError, CyclicAliasError, LoadingError, UnimportableModuleError -from griffe.expressions import Name +from griffe.expressions import ExprName from griffe.extensions import Extensions from griffe.finder import ModuleFinder, NamespacePackage, Package from griffe.logger import get_logger @@ -191,12 +191,12 @@ def expand_exports(self, module: Module, seen: set | None = None) -> None: return expanded = set() for export in module.exports: - if isinstance(export, Name): - module_path = export.full.rsplit(".", 1)[0] # remove trailing .__all__ + if isinstance(export, ExprName): + module_path = export.canonical_path.rsplit(".", 1)[0] # remove trailing .__all__ try: next_module = self.modules_collection.get_member(module_path) except KeyError: - logger.debug(f"Cannot expand '{export.full}', try pre-loading corresponding package") + logger.debug(f"Cannot expand '{export.canonical_path}', try pre-loading corresponding package") continue if next_module.path not in seen: self.expand_exports(next_module, seen) diff --git a/tests/test_docstrings/test_google.py b/tests/test_docstrings/test_google.py index 352dac6b..e982c42a 100644 --- a/tests/test_docstrings/test_google.py +++ b/tests/test_docstrings/test_google.py @@ -10,7 +10,7 @@ from griffe.dataclasses import Attribute, Class, Docstring, Function, Module, Parameter, Parameters from griffe.docstrings.dataclasses import DocstringSectionKind from griffe.docstrings.utils import parse_annotation -from griffe.expressions import Name +from griffe.expressions import ExprName if TYPE_CHECKING: from tests.test_docstrings.helpers import ParserType @@ -580,19 +580,19 @@ def test_parse_types_in_docstring(parse_google: ParserType) -> None: (returns,) = sections[2].value assert argx.name == "x" - assert argx.annotation.source == "int" - assert argx.annotation.full == "int" + assert argx.annotation.name == "int" + assert argx.annotation.canonical_path == "int" assert argx.description == "X value." assert argx.value is None assert argy.name == "y" - assert argy.annotation.source == "int" - assert argy.annotation.full == "int" + assert argy.annotation.name == "int" + assert argy.annotation.canonical_path == "int" assert argy.description == "Y value." assert argy.value is None - assert returns.annotation.source == "int" - assert returns.annotation.full == "int" + assert returns.annotation.name == "int" + assert returns.annotation.canonical_path == "int" assert returns.description == "Sum X + Y + Z." @@ -632,20 +632,20 @@ def test_parse_optional_type_in_docstring(parse_google: ParserType) -> None: (argz,) = sections[1].value assert argx.name == "x" - assert argx.annotation.source == "int" - assert argx.annotation.full == "int" + assert argx.annotation.name == "int" + assert argx.annotation.canonical_path == "int" assert argx.description == "X value." assert argx.value == "1" assert argy.name == "y" - assert argy.annotation.source == "int" - assert argy.annotation.full == "int" + assert argy.annotation.name == "int" + assert argy.annotation.canonical_path == "int" assert argy.description == "Y value." assert argy.value == "None" assert argz.name == "z" - assert argz.annotation.source == "int" - assert argz.annotation.full == "int" + assert argz.annotation.name == "int" + assert argz.annotation.canonical_path == "int" assert argz.description == "Z value." assert argz.value == "None" @@ -690,17 +690,17 @@ def test_prefer_docstring_types_over_annotations(parse_google: ParserType) -> No (returns,) = sections[2].value assert argx.name == "x" - assert argx.annotation.source == "str" - assert argx.annotation.full == "str" + assert argx.annotation.name == "str" + assert argx.annotation.canonical_path == "str" assert argx.description == "X value." assert argy.name == "y" - assert argy.annotation.source == "str" - assert argy.annotation.full == "str" + assert argy.annotation.name == "str" + assert argy.annotation.canonical_path == "str" assert argy.description == "Y value." - assert returns.annotation.source == "str" - assert returns.annotation.full == "str" + assert returns.annotation.name == "str" + assert returns.annotation.canonical_path == "str" assert returns.description == "Sum X + Y + Z." @@ -885,14 +885,14 @@ def test_retrieve_attributes_annotation_from_parent(parse_google: ParserType) -> b: Whatever. """ parent = Class("cls") - parent["a"] = Attribute("a", annotation=Name("int", "int")) - parent["b"] = Attribute("b", annotation=Name("str", "str")) + parent["a"] = Attribute("a", annotation=ExprName("int")) + parent["b"] = Attribute("b", annotation=ExprName("str")) sections, _ = parse_google(docstring, parent=parent) attributes = sections[1].value assert attributes[0].name == "a" - assert attributes[0].annotation.source == "int" + assert attributes[0].annotation.name == "int" assert attributes[1].name == "b" - assert attributes[1].annotation.source == "str" + assert attributes[1].annotation.name == "str" # ============================================================================================= @@ -947,9 +947,9 @@ def test_parse_yields_tuple_in_iterator_or_generator(parse_google: ParserType, r ) yields = sections[1].value assert yields[0].name == "a" - assert yields[0].annotation.source == "int" + assert yields[0].annotation.name == "int" assert yields[1].name == "b" - assert yields[1].annotation.source == "float" + assert yields[1].annotation.name == "float" @pytest.mark.parametrize( @@ -980,7 +980,7 @@ def test_extract_yielded_type_with_single_return_item(parse_google: ParserType, ), ) yields = sections[1].value - assert yields[0].annotation.source == "int" + assert yields[0].annotation.name == "int" # ============================================================================================= @@ -1007,9 +1007,9 @@ def test_parse_receives_tuple_in_generator(parse_google: ParserType) -> None: ) receives = sections[1].value assert receives[0].name == "a" - assert receives[0].annotation.source == "int" + assert receives[0].annotation.name == "int" assert receives[1].name == "b" - assert receives[1].annotation.source == "float" + assert receives[1].annotation.name == "float" @pytest.mark.parametrize( @@ -1039,7 +1039,7 @@ def test_extract_received_type_with_single_return_item(parse_google: ParserType, ), ) receives = sections[1].value - assert receives[0].annotation.source == "float" + assert receives[0].annotation.name == "float" # ============================================================================================= @@ -1066,9 +1066,9 @@ def test_parse_returns_tuple_in_generator(parse_google: ParserType) -> None: ) returns = sections[1].value assert returns[0].name == "a" - assert returns[0].annotation.source == "int" + assert returns[0].annotation.name == "int" assert returns[1].name == "b" - assert returns[1].annotation.source == "float" + assert returns[1].annotation.name == "float" # ============================================================================================= diff --git a/tests/test_docstrings/test_numpy.py b/tests/test_docstrings/test_numpy.py index c4cdad0e..95c95c06 100644 --- a/tests/test_docstrings/test_numpy.py +++ b/tests/test_docstrings/test_numpy.py @@ -9,18 +9,10 @@ from griffe.dataclasses import Attribute, Class, Docstring, Function, Module, Parameter, Parameters from griffe.docstrings.dataclasses import ( - DocstringAttribute, - DocstringParameter, - DocstringRaise, - DocstringReceive, - DocstringReturn, DocstringSectionKind, - DocstringWarn, - DocstringYield, ) from griffe.docstrings.utils import parse_annotation -from griffe.expressions import Name -from tests.test_docstrings.helpers import assert_attribute_equal, assert_element_equal, assert_parameter_equal +from griffe.expressions import ExprName if TYPE_CHECKING: from tests.test_docstrings.helpers import ParserType @@ -146,7 +138,10 @@ def test_prefer_docstring_type_over_annotation(parse_numpy: ParserType) -> None: parent=Function("func", parameters=Parameters(Parameter("a", annotation="str"))), ) assert len(sections) == 1 - assert_parameter_equal(sections[0].value[0], DocstringParameter("a", description="", annotation=Name("int", "int"))) + param = sections[0].value[0] + assert param.name == "a" + assert param.description == "" + assert param.annotation.name == "int" def test_parse_complex_annotations(parse_numpy: ParserType) -> None: @@ -201,7 +196,7 @@ def test_parse_annotations_in_all_sections(parse_numpy: ParserType, docstring: s docstring = docstring.format(name=name) sections, _ = parse_numpy(docstring, parent=Function("f")) assert len(sections) == 1 - assert sections[0].value[0].annotation == Name(name, name) + assert sections[0].value[0].annotation.name == name def test_dont_crash_on_text_annotations(parse_numpy: ParserType, caplog: pytest.LogCaptureFixture) -> None: @@ -330,7 +325,10 @@ def test_retrieve_annotation_from_parent(parse_numpy: ParserType) -> None: parent=Function("func", parameters=Parameters(Parameter("a", annotation="str"))), ) assert len(sections) == 1 - assert_parameter_equal(sections[0].value[0], DocstringParameter("a", description="", annotation="str")) + param = sections[0].value[0] + assert param.name == "a" + assert param.description == "" + assert param.annotation == "str" def test_deprecated_section(parse_numpy: ParserType) -> None: @@ -377,29 +375,31 @@ def test_returns_section(parse_numpy: ParserType) -> None: sections, _ = parse_numpy(docstring) assert len(sections) == 1 - assert_element_equal( - sections[0].value[0], - DocstringReturn(name="", annotation="list of int", description="A list of integers."), - ) - assert_element_equal( - sections[0].value[1], - DocstringReturn(name="flag", annotation="bool", description="Some kind\nof flag."), - ) - assert_element_equal( - sections[0].value[2], - DocstringReturn(name="x", annotation=None, description="Name only"), - ) + param = sections[0].value[0] + assert param.name == "" + assert param.description == "A list of integers." + assert param.annotation == "list of int" - assert_element_equal( - sections[0].value[3], - DocstringReturn(name="", annotation=None, description="No name or annotation"), - ) + param = sections[0].value[1] + assert param.name == "flag" + assert param.description == "Some kind\nof flag." + assert param.annotation == "bool" - assert_element_equal( - sections[0].value[4], - DocstringReturn(name="", annotation="int", description="Only annotation"), - ) + param = sections[0].value[2] + assert param.name == "x" + assert param.description == "Name only" + assert param.annotation is None + + param = sections[0].value[3] + assert param.name == "" + assert param.description == "No name or annotation" + assert param.annotation is None + + param = sections[0].value[4] + assert param.name == "" + assert param.description == "Only annotation" + assert param.annotation == "int" def test_yields_section(parse_numpy: ParserType) -> None: @@ -420,14 +420,15 @@ def test_yields_section(parse_numpy: ParserType) -> None: sections, _ = parse_numpy(docstring) assert len(sections) == 1 - assert_element_equal( - sections[0].value[0], - DocstringYield(name="", annotation="list of int", description="A list of integers."), - ) - assert_element_equal( - sections[0].value[1], - DocstringYield(name="flag", annotation="bool", description="Some kind\nof flag."), - ) + param = sections[0].value[0] + assert param.name == "" + assert param.description == "A list of integers." + assert param.annotation == "list of int" + + param = sections[0].value[1] + assert param.name == "flag" + assert param.description == "Some kind\nof flag." + assert param.annotation == "bool" def test_receives_section(parse_numpy: ParserType) -> None: @@ -448,14 +449,14 @@ def test_receives_section(parse_numpy: ParserType) -> None: sections, _ = parse_numpy(docstring) assert len(sections) == 1 - assert_element_equal( - sections[0].value[0], - DocstringReceive(name="", annotation="list of int", description="A list of integers."), - ) - assert_element_equal( - sections[0].value[1], - DocstringReceive(name="flag", annotation="bool", description="Some kind\nof flag."), - ) + param = sections[0].value[0] + assert param.name == "" + assert param.description == "A list of integers." + assert param.annotation == "list of int" + param = sections[0].value[1] + assert param.name == "flag" + assert param.description == "Some kind\nof flag." + assert param.annotation == "bool" def test_raises_section(parse_numpy: ParserType) -> None: @@ -473,10 +474,9 @@ def test_raises_section(parse_numpy: ParserType) -> None: sections, _ = parse_numpy(docstring) assert len(sections) == 1 - assert_element_equal( - sections[0].value[0], - DocstringRaise(annotation="RuntimeError", description="There was an issue."), - ) + param = sections[0].value[0] + assert param.description == "There was an issue." + assert param.annotation == "RuntimeError" def test_warns_section(parse_numpy: ParserType) -> None: @@ -494,7 +494,9 @@ def test_warns_section(parse_numpy: ParserType) -> None: sections, _ = parse_numpy(docstring) assert len(sections) == 1 - assert_element_equal(sections[0].value[0], DocstringWarn(annotation="ResourceWarning", description="Heads up.")) + param = sections[0].value[0] + assert param.description == "Heads up." + assert param.annotation == "ResourceWarning" def test_attributes_section(parse_numpy: ParserType) -> None: @@ -515,9 +517,20 @@ def test_attributes_section(parse_numpy: ParserType) -> None: sections, _ = parse_numpy(docstring) assert len(sections) == 1 - assert_attribute_equal(sections[0].value[0], DocstringAttribute(name="a", annotation=None, description="Hello.")) - assert_attribute_equal(sections[0].value[1], DocstringAttribute(name="m", annotation=None, description="")) - assert_attribute_equal(sections[0].value[2], DocstringAttribute(name="z", annotation="int", description="Bye.")) + param = sections[0].value[0] + assert param.name == "a" + assert param.description == "Hello." + assert param.annotation is None + + param = sections[0].value[1] + assert param.name == "m" + assert param.description == "" + assert param.annotation is None + + param = sections[0].value[2] + assert param.name == "z" + assert param.description == "Bye." + assert param.annotation == "int" def test_examples_section(parse_numpy: ParserType) -> None: @@ -624,14 +637,14 @@ def test_retrieve_attributes_annotation_from_parent(parse_numpy: ParserType) -> Whatever. """ parent = Class("cls") - parent["a"] = Attribute("a", annotation=Name("int", "int")) - parent["b"] = Attribute("b", annotation=Name("str", "str")) + parent["a"] = Attribute("a", annotation=ExprName("int")) + parent["b"] = Attribute("b", annotation=ExprName("str")) sections, _ = parse_numpy(docstring, parent=parent) attributes = sections[1].value assert attributes[0].name == "a" - assert attributes[0].annotation.source == "int" + assert attributes[0].annotation.name == "int" assert attributes[1].name == "b" - assert attributes[1].annotation.source == "str" + assert attributes[1].annotation.name == "str" # ============================================================================================= @@ -792,9 +805,9 @@ def test_parse_yields_tuple_in_iterator_or_generator(parse_numpy: ParserType, re ) yields = sections[1].value assert yields[0].name == "a" - assert yields[0].annotation.source == "int" + assert yields[0].annotation.name == "int" assert yields[1].name == "b" - assert yields[1].annotation.source == "float" + assert yields[1].annotation.name == "float" @pytest.mark.parametrize( @@ -827,7 +840,7 @@ def test_extract_yielded_type_with_single_return_item(parse_numpy: ParserType, r ), ) yields = sections[1].value - assert yields[0].annotation.source == "int" + assert yields[0].annotation.name == "int" # ============================================================================================= @@ -857,9 +870,9 @@ def test_parse_receives_tuple_in_generator(parse_numpy: ParserType) -> None: ) receives = sections[1].value assert receives[0].name == "a" - assert receives[0].annotation.source == "int" + assert receives[0].annotation.name == "int" assert receives[1].name == "b" - assert receives[1].annotation.source == "float" + assert receives[1].annotation.name == "float" @pytest.mark.parametrize( @@ -891,7 +904,7 @@ def test_extract_received_type_with_single_return_item(parse_numpy: ParserType, ), ) receives = sections[1].value - assert receives[0].annotation.source == "float" + assert receives[0].annotation.name == "float" # ============================================================================================= @@ -921,9 +934,9 @@ def test_parse_returns_tuple_in_generator(parse_numpy: ParserType) -> None: ) returns = sections[1].value assert returns[0].name == "a" - assert returns[0].annotation.source == "int" + assert returns[0].annotation.name == "int" assert returns[1].name == "b" - assert returns[1].annotation.source == "float" + assert returns[1].annotation.name == "float" # ============================================================================================= diff --git a/tests/test_encoders.py b/tests/test_encoders.py index 3771039b..a4742189 100644 --- a/tests/test_encoders.py +++ b/tests/test_encoders.py @@ -26,13 +26,12 @@ def test_minimal_data_is_enough() -> None: assert reloaded.as_json(full=False) == minimal assert reloaded.as_json(full=True) == full - # also works (but will result in a different type hint) + # Also works (but will result in a different type hint). assert Object.from_json(minimal) # Won't work if the JSON doesn't represent the type requested. - with pytest.raises(TypeError) as err: + with pytest.raises(TypeError, match="provided JSON object is not of type"): Function.from_json(minimal) - assert "provided JSON object is not of type" in str(err.value) # use this function in test_json_schema to ease schema debugging diff --git a/tests/test_expressions.py b/tests/test_expressions.py index d5c23065..7f1b0209 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -2,10 +2,15 @@ from __future__ import annotations +import ast + import pytest +from griffe.dataclasses import Module from griffe.docstrings.parsers import Parser +from griffe.expressions import get_expression from griffe.tests import temporary_visited_module +from tests.test_nodes import syntax_examples @pytest.mark.parametrize( @@ -55,9 +60,7 @@ def test_full_expressions(annotation: str) -> None: """Assert we can transform expressions to their full form without errors.""" code = f"x: {annotation}" with temporary_visited_module(code) as module: - obj = module["x"] - res = obj.annotation.full - assert res == "".join(annotation) + assert str(module["x"].annotation) == annotation def test_resolving_full_names() -> None: @@ -71,5 +74,17 @@ def test_resolving_full_names() -> None: attribute2: mod.Class """, ) as module: - assert module["attribute1"].annotation.full == "package.module.Class" - assert module["attribute2"].annotation.full == "package.module.Class" + assert module["attribute1"].annotation.canonical_path == "package.module.Class" + assert module["attribute2"].annotation.canonical_path == "package.module.Class" + + +@pytest.mark.parametrize("code", syntax_examples) +def test_expressions(code: str) -> None: + """Test building annotations from AST nodes. + + Parameters: + code: An expression (parametrized). + """ + top_node = compile(code, filename="<>", mode="eval", flags=ast.PyCF_ONLY_AST, optimize=2) + expression = get_expression(top_node.body, parent=Module("module")) # type: ignore[attr-defined] + assert str(expression) == code diff --git a/tests/test_functions.py b/tests/test_functions.py index 4ae5b90a..5f1e431f 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -105,14 +105,14 @@ def test_visit_function_variadic_params() -> None: assert len(function.parameters) == 3 param = function.parameters[0] assert param.name == "args" - assert param.annotation.source == "str" - assert param.annotation.full == "str" + assert param.annotation.name == "str" + assert param.annotation.canonical_path == "str" param = function.parameters[1] assert param.annotation is None param = function.parameters[2] assert param.name == "kwargs" - assert param.annotation.source == "int" - assert param.annotation.full == "int" + assert param.annotation.name == "int" + assert param.annotation.canonical_path == "int" def test_visit_function_params_annotations() -> None: @@ -132,11 +132,11 @@ def f_annorations( function = module["f_annorations"] assert len(function.parameters) == 4 param = function.parameters[0] - assert param.annotation.source == "str" - assert param.annotation.full == "str" + assert param.annotation.name == "str" + assert param.annotation.canonical_path == "str" param = function.parameters[1] - assert param.annotation.source == "Any" - assert param.annotation.full == "typing.Any" + assert param.annotation.name == "Any" + assert param.annotation.canonical_path == "typing.Any" param = function.parameters[2] assert str(param.annotation) == "typing.Optional[typing.List[int]]" param = function.parameters[3] diff --git a/tests/test_inspector.py b/tests/test_inspector.py index cd603fae..dabf8db2 100644 --- a/tests/test_inspector.py +++ b/tests/test_inspector.py @@ -8,7 +8,6 @@ import pytest from griffe.agents.inspector import inspect -from griffe.expressions import Name from griffe.tests import temporary_inspected_module, temporary_pypackage @@ -17,8 +16,8 @@ def test_annotations_from_builtin_types() -> None: with temporary_inspected_module("def func(a: int) -> str: pass") as module: func = module["func"] assert func.parameters[0].name == "a" - assert func.parameters[0].annotation == Name("int", full="int") - assert func.returns == Name("str", full="str") + assert func.parameters[0].annotation.name == "int" + assert func.returns.name == "str" def test_annotations_from_classes() -> None: @@ -26,8 +25,12 @@ def test_annotations_from_classes() -> None: with temporary_inspected_module("class A: pass\ndef func(a: A) -> A: pass") as module: func = module["func"] assert func.parameters[0].name == "a" - assert func.parameters[0].annotation == Name("A", full=f"{module.name}.A") - assert func.returns == Name("A", full=f"{module.name}.A") + param = func.parameters[0].annotation + assert param.name == "A" + assert param.canonical_path == f"{module.name}.A" + returns = func.returns + assert returns.name == "A" + assert returns.canonical_path == f"{module.name}.A" def test_class_level_imports() -> None: @@ -41,7 +44,9 @@ def method(self, p: StringIO): """, ) as module: method = module["A.method"] - assert method.parameters["p"].annotation == Name("StringIO", full="io.StringIO") + name = method.parameters["p"].annotation + assert name.name == "StringIO" + assert name.canonical_path == "io.StringIO" def test_missing_dependency() -> None: diff --git a/tests/test_loader.py b/tests/test_loader.py index 74d47252..f28e21e3 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -6,7 +6,7 @@ from textwrap import dedent from typing import TYPE_CHECKING -from griffe.expressions import Name +from griffe.expressions import ExprName from griffe.loader import GriffeLoader from griffe.tests import temporary_pyfile, temporary_pypackage @@ -70,9 +70,9 @@ def test_dont_shortcut_alias_chain_after_expanding_wildcards() -> None: child = package["mod_a.Child"] assert child.bases base = child.bases[0] - assert isinstance(base, Name) - assert base.source == "Base" - assert base.full == "package.mod_b.Base" + assert isinstance(base, ExprName) + assert base.name == "Base" + assert base.canonical_path == "package.mod_b.Base" def test_dont_overwrite_lower_member_when_expanding_wildcard() -> None: @@ -182,7 +182,7 @@ def function2(self, arg1: float) -> float: ... assert "CONST" in mod.members const = mod["CONST"] assert const.value == "0" - assert const.annotation.source == "int" + assert const.annotation.name == "int" assert "Class" in mod.members class_ = mod["Class"] @@ -190,7 +190,7 @@ def function2(self, arg1: float) -> float: ... assert "class_attr" in class_.members class_attr = class_["class_attr"] assert class_attr.value == "True" - assert class_attr.annotation.source == "bool" + assert class_attr.annotation.name == "bool" assert "function1" in class_.members function1 = class_["function1"] @@ -198,8 +198,8 @@ def function2(self, arg1: float) -> float: ... assert "function2" in class_.members function2 = class_["function2"] - assert function2.returns.source == "float" - assert function2.parameters["arg1"].annotation.source == "float" + assert function2.returns.name == "float" + assert function2.parameters["arg1"].annotation.name == "float" assert function2.parameters["arg1"].default == "2.2" diff --git a/tests/test_nodes.py b/tests/test_nodes.py index 16ab48d8..0a74ee07 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -9,7 +9,7 @@ import pytest from griffe.agents.nodes import get_value, relative_to_absolute -from griffe.expressions import Expression, Name +from griffe.expressions import Expr, ExprName from griffe.tests import module_vtree, temporary_visited_module syntax_examples = [ @@ -56,9 +56,9 @@ "f'a {round(key, 2)} {z}'", # slices "o[x]", - "o[x,y]", + "o[x, y]", "o[x:y]", - "o[x:y,z]", + "o[x:y, z]", "o[x, y(z)]", # walrus operator "a if (a := b) else c", @@ -155,18 +155,6 @@ def test_building_expressions_from_nodes(code: str) -> None: assert value == code.replace(", ", ",") -def _flat(expression: str | Name | Expression) -> list[str | Name]: - if not isinstance(expression, Expression): - return [expression] - items = [] - for item in expression: - if isinstance(item, Expression): - items.extend(_flat(item)) - else: - items.append(item) - return items - - @pytest.mark.parametrize( ("code", "has_name"), [ @@ -186,13 +174,13 @@ def test_forward_references(code: str, has_name: bool) -> None: has_name: Whether the annotation should contain a Name rather than a string. """ with temporary_visited_module(code) as module: - flat = _flat(module["a"].annotation) + annotation = module["a"].annotation if has_name: - assert any(isinstance(item, Name) and item.source == "A" for item in flat) - assert all(not (isinstance(item, str) and item == "A") for item in flat) + assert any(isinstance(item, ExprName) and item.name == "A" for item in annotation) + assert all(not (isinstance(item, str) and item == "A") for item in annotation) else: - assert any(isinstance(item, str) and item == "'A'" for item in flat) - assert all(not (isinstance(item, Name) and item.source == "A") for item in flat) + assert "A" in annotation + assert all(not (isinstance(item, ExprName) and item.name == "A") for item in annotation) @pytest.mark.parametrize( @@ -268,10 +256,10 @@ def kwargs(self) -> 'dict[str, Any] | None': """, ) as module: init_args_annotation = module["ArgsKwargs.__init__"].parameters["args"].annotation - assert isinstance(init_args_annotation, Expression) + assert isinstance(init_args_annotation, Expr) assert init_args_annotation.is_tuple kwargs_return_annotation = module["ArgsKwargs.kwargs"].annotation - assert isinstance(kwargs_return_annotation, Expression) + assert isinstance(kwargs_return_annotation, Expr) def test_parsing_dynamic_base_classes(caplog: pytest.LogCaptureFixture) -> None: diff --git a/tests/test_visitor.py b/tests/test_visitor.py index d2abd215..091e5f87 100644 --- a/tests/test_visitor.py +++ b/tests/test_visitor.py @@ -170,10 +170,10 @@ def absolute(self, path: str | Path) -> str | Path: with temporary_visited_module(code) as module: overloads = module["A.absolute"].overloads assert len(overloads) == 2 - assert overloads[0].parameters["path"].annotation.source == "str" - assert overloads[1].parameters["path"].annotation.source == "Path" - assert overloads[0].returns.source == "str" - assert overloads[1].returns.source == "Path" + assert overloads[0].parameters["path"].annotation.name == "str" + assert overloads[1].parameters["path"].annotation.name == "Path" + assert overloads[0].returns.name == "str" + assert overloads[1].returns.name == "Path" @pytest.mark.parametrize( @@ -317,27 +317,27 @@ def __init__(self) -> None: self.b: bytes """, ) as module: - assert module["C.w"].annotation.full == "str" + assert module["C.w"].annotation.canonical_path == "str" assert module["C.w"].labels == {"class-attribute"} assert module["C.w"].value == "'foo'" - assert module["C.x"].annotation.full == "int" + assert module["C.x"].annotation.canonical_path == "int" assert module["C.x"].labels == {"class-attribute"} - assert module["C.y"].annotation.full == "str" + assert module["C.y"].annotation.canonical_path == "str" assert module["C.y"].labels == {"instance-attribute"} assert module["C.y"].value == "''" - assert module["C.z"].annotation.full == "int" + assert module["C.z"].annotation.canonical_path == "int" assert module["C.z"].labels == {"class-attribute", "instance-attribute"} assert module["C.z"].value == "5" # This is syntactically valid, but semantically invalid - assert module["C.a"].annotation[0].full == "typing.ClassVar" - assert module["C.a"].annotation[2].full == "float" + assert module["C.a"].annotation.canonical_path == "typing.ClassVar" + assert module["C.a"].annotation.slice.canonical_path == "float" assert module["C.a"].labels == {"instance-attribute"} - assert module["C.b"].annotation.full == "bytes" + assert module["C.b"].annotation.canonical_path == "bytes" assert module["C.b"].labels == {"instance-attribute"}