Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

apidoc/python: Fix xrefs to type parameters #372

Merged
merged 1 commit into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion docs/type_param_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,11 @@ def get(self, key: K, default: V) -> V: ...
def get(self, key: K, default: T) -> Union[V, T]: ...

def get(self, key: K, default=None):
"""Return the mapped value, or the specified default."""
"""Return the mapped value, or the specified default.
:param key: Key to retrieve.
:param default: Default value to return if key is not present.
"""
...

def __len__(self) -> int:
Expand Down
43 changes: 31 additions & 12 deletions sphinx_immaterial/apidoc/python/parameter_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,16 +253,29 @@ def get_objects(
PythonDomain.get_objects = get_objects # type: ignore[assignment]


def _fix_pending_xrefs_to_type_params(
type_param_symbols: dict[str, str], parent: docutils.nodes.Node
) -> None:
for xref in parent.findall(condition=sphinx.addnodes.pending_xref):
if xref["refdomain"] == "py" and xref["reftype"] in ("class", "param"):
p_symbol = type_param_symbols.get(xref["reftarget"])
if p_symbol is not None:
xref["reftarget"] = p_symbol
xref["refspecific"] = False


def _add_parameter_links_to_signature(
env: sphinx.environment.BuildEnvironment,
signode: sphinx.addnodes.desc_signature,
type_param_symbol_prefix: str,
function_param_symbol_prefix: str,
) -> Dict[str, docutils.nodes.Element]:
) -> tuple[dict[str, docutils.nodes.Element], dict[str, str]]:
"""Cross-links parameter names in signature to parameter objects.

Returns:
Map of parameter name to original (not linked) parameter node.
Tuple of:
- Map of parameter name to original (not linked) parameter node.
- Map of type parameter name to parameter object symbol.
"""
sig_param_nodes: Dict[str, docutils.nodes.Element] = {}

Expand Down Expand Up @@ -336,15 +349,13 @@ def _collect_parameters(
refnode["implicit_sig_param_ref"] = True
name_node.replace_self(refnode)

# Also cross-link references to type parameters in annotations.
for xref in signode.findall(condition=sphinx.addnodes.pending_xref):
if xref["refdomain"] == "py" and xref["reftype"] in ("class", "param"):
p_symbol = type_param_symbols.get(xref["reftarget"])
if p_symbol is not None:
xref["reftarget"] = p_symbol
xref["refspecific"] = False
if type_param_symbols:
# Also cross-link references to type parameters in annotations.
_fix_pending_xrefs_to_type_params(type_param_symbols, signode)
for parent in sig_param_nodes.values():
_fix_pending_xrefs_to_type_params(type_param_symbols, parent)

return sig_param_nodes
return sig_param_nodes, type_param_symbols


def _collate_parameter_symbols(
Expand Down Expand Up @@ -550,6 +561,8 @@ def _cross_link_parameters(
env = app.env
assert isinstance(env, sphinx.environment.BuildEnvironment)

type_param_symbols: dict[str, str] = {}

# Collect the docutils nodes corresponding to the declarations of the
# parameters in each signature, and turn the parameter names into
# cross-links to the parameter description.
Expand All @@ -559,9 +572,11 @@ def _cross_link_parameters(
# e.g. `x : int = 10` rather than just `x`.
sig_param_nodes_for_signature = []
for signode, symbol, function_symbol in zip(signodes, symbols, function_symbols):
sig_param_nodes_for_signature.append(
_add_parameter_links_to_signature(env, signode, symbol, function_symbol)
sig_param_nodes, sig_type_param_symbols = _add_parameter_links_to_signature(
env, signode, symbol, function_symbol
)
sig_param_nodes_for_signature.append(sig_param_nodes)
type_param_symbols.update(sig_type_param_symbols)

# Find all parameter descriptions in the object description body, and mark
# them as the target for cross links to that parameter. Also substitute in
Expand All @@ -576,6 +591,10 @@ def _cross_link_parameters(
noindex=noindex,
)

# Fix any remaining references to type parameters.
if type_param_symbols:
_fix_pending_xrefs_to_type_params(type_param_symbols, content)

if not noindex:
py = cast(sphinx.domains.python.PythonDomain, env.get_domain("py"))

Expand Down
31 changes: 30 additions & 1 deletion tests/python_apigen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,15 @@ def apigen_make_app(tmp_path: pathlib.Path, make_app):

def make(extra_conf: str = "", **kwargs):
(tmp_path / "conf.py").write_text(conf + extra_conf, encoding="utf-8")
(tmp_path / "index.rst").write_text("", encoding="utf-8")
(tmp_path / "index.rst").write_text(
"""
.. python-apigen-group:: Public Members

.. python-apigen-group:: Classes

""",
encoding="utf-8",
)
return make_app(srcdir=SphinxPath(str(tmp_path)), **kwargs)

yield make
Expand Down Expand Up @@ -157,3 +165,24 @@ def test_pure_python_property(apigen_make_app):
assert member.name == "baz"
assert len(member.siblings) == 1
assert member.siblings[0].name == "bar"


@pytest.mark.skipif(
sphinx.version_info < (7, 1),
reason=f"Type parameters are not supported by Sphinx {sphinx.version_info}",
)
def test_type_params(apigen_make_app):
"""Tests that references to type parameters are all resolved."""
testmod = "python_apigen_test_modules.type_params"
app = apigen_make_app(
confoverrides=dict(
python_apigen_modules={
testmod: "api/",
},
nitpicky=True,
),
)
app.build()
print(app._status.getvalue())
print(app._warning.getvalue())
assert not app._warning.getvalue()
29 changes: 29 additions & 0 deletions tests/python_apigen_test_modules/type_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import TypeVar

T = TypeVar("T")


def foo(x: T) -> T:
"""Foo function.

:param x: Something or other.
"""
return x


def bar(x: T) -> T:
return x


class C:
def get(self, x: T, y: T) -> T:
"""Get function.

:param x: Something or other.
:param y: Another param.
:type y: T
"""
return x


__all__ = ["foo", "bar", "C"]
Loading