Skip to content

Commit

Permalink
feat: Add async loader
Browse files Browse the repository at this point in the history
  • Loading branch information
pawamoy committed Sep 18, 2021
1 parent 2fb19c2 commit 3218bd0
Show file tree
Hide file tree
Showing 2 changed files with 192 additions and 57 deletions.
88 changes: 62 additions & 26 deletions src/griffe/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from __future__ import annotations

import argparse
import asyncio
import json
import logging
import sys
Expand All @@ -22,7 +23,7 @@
from griffe.encoders import Encoder
from griffe.extended_ast import extend_ast
from griffe.extensions import Extensions
from griffe.loader import GriffeLoader
from griffe.loader import AsyncGriffeLoader, GriffeLoader
from griffe.logger import get_logger

logger = get_logger(__name__)
Expand All @@ -36,38 +37,80 @@ def _print_data(data, output_file):
print(data, file=fd)


async def _load_packages_async(packages, extensions, search_paths):
    """Load the given packages one after the other with the asynchronous loader.

    Arguments:
        packages: The packages to load, as accepted by the loader's `load_module`.
        extensions: The extensions to pass to the loader.
        search_paths: The paths to search the packages into.

    Returns:
        A dictionary mapping each loaded module's name to the module.
        Packages that cannot be found are reported with an error log and left out.
    """
    async_loader = AsyncGriffeLoader(extensions=extensions)
    modules_by_name = {}
    for package in packages:
        logger.info(f"Loading package {package}")
        try:
            loaded_module = await async_loader.load_module(package, search_paths=search_paths)
        except ModuleNotFoundError:
            logger.error(f"Could not find package {package}")
            continue
        modules_by_name[loaded_module.name] = loaded_module
    return modules_by_name


def _load_packages(packages, extensions, search_paths):
    """Load the given packages one after the other with the blocking loader.

    Arguments:
        packages: The packages to load, as accepted by the loader's `load_module`.
        extensions: The extensions to pass to the loader.
        search_paths: The paths to search the packages into.

    Returns:
        A dictionary mapping each loaded module's name to the module.
        Packages that cannot be found are reported with an error log and left out.
    """
    sync_loader = GriffeLoader(extensions=extensions)
    modules_by_name = {}
    for package in packages:
        logger.info(f"Loading package {package}")
        try:
            loaded_module = sync_loader.load_module(package, search_paths=search_paths)
        except ModuleNotFoundError:
            logger.error(f"Could not find package {package}")
            continue
        modules_by_name[loaded_module.name] = loaded_module
    return modules_by_name


def get_parser() -> argparse.ArgumentParser:
"""
Return the program argument parser.
Returns:
The argument parser for the program.
"""
parser = argparse.ArgumentParser(prog="griffe")
parser = argparse.ArgumentParser(prog="griffe", add_help=False)
parser.add_argument(
"-s",
"--search",
action="append",
type=Path,
help="Paths to search packages into.",
"-A",
"--async-loader",
action="store_true",
help="Whether to read files on disk asynchronously. "
"Very large projects with many files will be processed faster. "
"Small projects with a few files will not see any speed up.",
)
parser.add_argument(
"-a",
"--append-search",
"--append-sys-path",
action="store_true",
help="Whether to append sys.path to specified search paths.",
help="Whether to append sys.path to search paths specified with -s.",
)
parser.add_argument(
"-h",
"--help",
action="help",
help="Show this help message and exit.",
)
parser.add_argument(
"-o",
"--output",
default=sys.stdout,
help="Output file. Supports templating to output each package in its own file, with {{package}}.",
)
parser.add_argument(
"-s",
"--search",
action="append",
type=Path,
help="Paths to search packages into.",
)
parser.add_argument("packages", metavar="PACKAGE", nargs="+", help="Packages to find and parse.")
return parser


def main(args: list[str] | None = None) -> int:
def main(args: list[str] | None = None) -> int: # noqa: WPS231
"""
Run the main program.
Expand All @@ -82,7 +125,7 @@ def main(args: list[str] | None = None) -> int:
parser = get_parser()
opts: argparse.Namespace = parser.parse_args(args) # type: ignore

logging.basicConfig(format="%(levelname)-10s %(message)s", level=logging.INFO) # noqa: WPS323
logging.basicConfig(format="%(levelname)-10s %(message)s", level=logging.WARNING) # noqa: WPS323

output = opts.output

Expand All @@ -91,25 +134,18 @@ def main(args: list[str] | None = None) -> int:
per_package_output = True

search = opts.search
if opts.append_search:
if opts.append_sys_path:
search.extend(sys.path)

extend_ast()

extensions = Extensions()
loader = GriffeLoader(extensions=extensions)
packages = {}
success = True

for package in opts.packages:
logger.info(f"Loading package {package}")
try:
module = loader.load_module(package, search_paths=search)
except ModuleNotFoundError:
logger.error(f"Could not find package {package}")
success = False
else:
packages[module.name] = module
if opts.async_loader:
loop = asyncio.get_event_loop()
coroutine = _load_packages_async(opts.packages, extensions=extensions, search_paths=search)
packages = loop.run_until_complete(coroutine)
else:
packages = _load_packages(opts.packages, extensions=extensions, search_paths=search)

if per_package_output:
for package_name, data in packages.items():
Expand All @@ -119,4 +155,4 @@ def main(args: list[str] | None = None) -> int:
serialized = json.dumps(packages, cls=Encoder, indent=2, full=True)
_print_data(serialized, output)

return 0 if success else 1
return 0 if len(packages) == len(opts.packages) else 1
161 changes: 130 additions & 31 deletions src/griffe/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,42 @@

from __future__ import annotations

import asyncio
import sys
from pathlib import Path
from typing import Iterator
from typing import Iterator, Tuple

from griffe.collections import lines_collection
from griffe.dataclasses import Module
from griffe.extensions import Extensions
from griffe.logger import get_logger
from griffe.visitor import visit

NamePartsType = Tuple[str, ...]
NamePartsAndPathType = Tuple[NamePartsType, Path]

logger = get_logger(__name__)


class GriffeLoader:
"""The griffe loader, allowing to load data from modules.
async def _read_async(path):
    """Read a file through `aopen` (the `open` function imported from aiofiles).

    Arguments:
        path: The path of the file to read.

    Returns:
        The result of awaiting `read()` on the opened file.
    """
    async with aopen(path) as file_handle:
        contents = await file_handle.read()
        return contents

Attributes:
extensions: The extensions to use.
"""

async def _fallback_read_async(path):
    """Read a file with a regular blocking call, behind a coroutine interface.

    A warning is logged on every call, since the read does not yield
    control back to the event loop.

    Arguments:
        path: The path of the file to read (must provide `read_text()`).

    Returns:
        The text returned by `path.read_text()`.
    """
    logger.warning("aiofiles is not installed, fallback to blocking read")
    text = path.read_text()
    return text


# aiofiles is an optional dependency: pick the implementation of `read_async`
# once, at import time, depending on whether it can be imported.
try:
    from aiofiles import open as aopen  # type: ignore
except ModuleNotFoundError:
    # Without aiofiles, keep the same coroutine interface but read synchronously.
    read_async = _fallback_read_async
else:
    # `aopen` is available, so `_read_async` can use it.
    read_async = _read_async


class _BaseGriffeLoader:
def __init__(self, extensions: Extensions | None = None) -> None:
"""Initialize the loader.
Expand All @@ -40,35 +56,51 @@ def __init__(self, extensions: Extensions | None = None) -> None:
"""
self.extensions = extensions or Extensions()

def _module_name_and_path(
self,
module: str | Path,
search_paths: list[str | Path] | None = None,
) -> tuple[str, Path]:
if isinstance(module, Path):
# programmatically passed a Path, try only that
module_name, module_path = module_name_path(module)
else:
# passed a string (from CLI or Python code), try both
try:
module_name, module_path = module_name_path(Path(module))
except FileNotFoundError:
module_name = module
module_path = find_module(module_name, search_paths=search_paths)
return module_name, module_path


class GriffeLoader(_BaseGriffeLoader):
"""The Griffe loader, allowing to load data from modules.
Attributes:
extensions: The extensions to use.
"""

def load_module(
self,
module: str | Path,
recursive: bool = True,
submodules: bool = True,
search_paths: list[str | Path] | None = None,
) -> Module:
"""Load a module.
Arguments:
module: The module name or path.
recursive: Whether to recurse on the submodules.
submodules: Whether to recurse on the submodules.
search_paths: The paths to search into.
Returns:
A module.
"""
if isinstance(module, Path):
# programmatically passed a Path, try only that
module_name, module_path = module_name_path(module)
else:
# passed a string (from CLI or Python code), try both
try:
module_name, module_path = module_name_path(Path(module))
except FileNotFoundError:
module_name = module
module_path = find_module(module_name, search_paths=search_paths)
return self._load_module_path(module_name, module_path, recursive=recursive)
module_name, module_path = self._module_name_and_path(module, search_paths)
return self._load_module_path(module_name, module_path, submodules=submodules)

def _load_module_path(self, module_name, module_path, recursive=True):
def _load_module_path(self, module_name: str, module_path: Path, submodules: bool = True) -> Module:
logger.debug(f"Loading path {module_path}")
code = module_path.read_text()
lines_collection[module_path] = code.splitlines(keepends=False)
Expand All @@ -78,19 +110,83 @@ def _load_module_path(self, module_name, module_path, recursive=True):
code=code,
extensions=self.extensions,
)
if recursive:
for subparts, subpath in sorted(iter_submodules(module_path), key=_module_depth):
parent_parts = subparts[:-1]
try:
member_parent = module[parent_parts]
except KeyError:
logger.debug(f"Skipping (not importable) {subpath}")
continue
member_parent[subparts[-1]] = self._load_module_path(subparts[-1], subpath, recursive=False)
if submodules:
self._load_submodules(module)
return module

def _load_submodules(self, module: Module) -> None:
for subparts, subpath in sorted(iter_submodules(module.filepath), key=_module_depth):
self._load_submodule(module, subparts, subpath)

def _load_submodule(self, module: Module, subparts: NamePartsType, subpath: Path) -> None:
parent_parts = subparts[:-1]
try:
member_parent = module[parent_parts]
except KeyError:
logger.debug(f"Skipping (not importable) {subpath}")
else:
member_parent[subparts[-1]] = self._load_module_path(subparts[-1], subpath, submodules=False)


class AsyncGriffeLoader(_BaseGriffeLoader):
    """The asynchronous Griffe loader, allowing to load data from modules.

    Attributes:
        extensions: The extensions to use.
    """

    async def load_module(
        self,
        module: str | Path,
        submodules: bool = True,
        search_paths: list[str | Path] | None = None,
    ) -> Module:
        """Load a module.

        Arguments:
            module: The module name or path.
            submodules: Whether to recurse on the submodules.
            search_paths: The paths to search into.

        Returns:
            A module.
        """
        module_name, module_path = self._module_name_and_path(module, search_paths)
        return await self._load_module_path(module_name, module_path, submodules=submodules)

    async def _load_module_path(self, module_name: str, module_path: Path, submodules: bool = True) -> Module:
        """Read and visit a single file, then optionally load its submodules.

        Arguments:
            module_name: The name to give to the visited module.
            module_path: The path of the file to read.
            submodules: Whether to also load the submodules found next to the file.

        Returns:
            The module returned by `visit`.
        """
        logger.debug(f"Loading path {module_path}")
        code = await read_async(module_path)
        # Keep the source lines around, indexed by file path.
        lines_collection[module_path] = code.splitlines(keepends=False)
        module = visit(
            module_name,
            filepath=module_path,
            code=code,
            extensions=self.extensions,
        )
        if submodules:
            await self._load_submodules(module)
        return module

    async def _load_submodules(self, module: Module) -> None:
        """Load the submodules of a module, one depth level at a time.

        `_load_submodule` looks up the parent of a submodule *before* its first
        suspension point. Gathering every submodule at once would therefore let
        nested submodules look for a parent package that is still being read,
        get a `KeyError`, and be wrongly skipped as "not importable".
        Submodules of the same depth never depend on each other, so they are
        still read concurrently; a deeper level only starts once every
        shallower level has been attached to the tree.

        Arguments:
            module: The module whose submodules must be loaded.
        """
        # Bucket by depth; insertion order keeps the order of `iter_submodules`
        # within a level, like the stable sort on depth used by the sync loader.
        by_depth = {}
        for name_parts_and_path in iter_submodules(module.filepath):
            by_depth.setdefault(_module_depth(name_parts_and_path), []).append(name_parts_and_path)

        for depth in sorted(by_depth):
            await asyncio.gather(
                *[self._load_submodule(module, subparts, subpath) for subparts, subpath in by_depth[depth]]
            )

    async def _load_submodule(self, module: Module, subparts: NamePartsType, subpath: Path) -> None:
        """Load one submodule and attach it to its parent in the module tree.

        Arguments:
            module: The top module, root of the tree.
            subparts: The name parts of the submodule, relative to the top module.
            subpath: The path of the submodule file.
        """
        parent_parts = subparts[:-1]
        try:
            member_parent = module[parent_parts]
        except KeyError:
            # The parent is not in the tree, so the submodule cannot be attached.
            logger.debug(f"Skipping (not importable) {subpath}")
        else:
            member_parent[subparts[-1]] = await self._load_module_path(subparts[-1], subpath, submodules=False)


def _module_depth(name_parts_and_path):
def _module_depth(name_parts_and_path: NamePartsAndPathType) -> int:
    """Return the number of name parts of a (name parts, path) pair.

    Used as the sort key that orders submodules from shallowest to deepest.

    Arguments:
        name_parts_and_path: A pair whose first item is the sequence of name parts.

    Returns:
        The length of the name parts sequence.
    """
    name_parts = name_parts_and_path[0]
    return len(name_parts)


Expand All @@ -117,12 +213,15 @@ def module_name_path(path: Path) -> tuple[str, Path]:
raise FileNotFoundError
if path.exists():
if path.stem == "__init__":
if path.parent.is_absolute():
return path.parent.name, path
return path.parent.resolve().name, path
return path.stem, path
raise FileNotFoundError


# credits to @NiklasRosenstein and the docspec project
# TODO: possible optimization by caching elements of search directories
def find_module(module_name: str, search_paths: list[str | Path] | None = None) -> Path:
"""Find a module in a given list of paths or in `sys.path`.
Expand Down Expand Up @@ -156,7 +255,7 @@ def find_module(module_name: str, search_paths: list[str | Path] | None = None)
raise ModuleNotFoundError(module_name)


def iter_submodules(path) -> Iterator[tuple[list[str], Path]]: # noqa: WPS234
def iter_submodules(path: Path) -> Iterator[NamePartsAndPathType]: # noqa: WPS234
"""Iterate on a module's submodules, if any.
Arguments:
Expand Down

0 comments on commit 3218bd0

Please sign in to comment.