diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index dfb871f..c9f6785 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,31 +6,20 @@ repos: - id: check-yaml - id: check-toml - - repo: https://github.com/pycqa/isort - rev: 5.12.0 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.0.278 hooks: - - id: isort - additional_dependencies: [toml] + - id: ruff + args: [--fix, --exit-non-zero-on-fix] - repo: https://github.com/psf/black - rev: 23.1.0 + rev: 23.7.0 hooks: - id: black - repo: https://github.com/asottile/blacken-docs - rev: 1.13.0 + rev: 1.15.0 hooks: - id: blacken-docs - additional_dependencies: [black==23.1.0] - args: [-l, '79', -t, py38] - - - repo: https://github.com/pycqa/pydocstyle - rev: 6.3.0 - hooks: - - id: pydocstyle - additional_dependencies: [tomli] - - - repo: https://github.com/pycqa/flake8 - rev: 6.0.0 - hooks: - - id: flake8 + additional_dependencies: [black==23.7.0] + args: [-l, '79', -t, py310] diff --git a/changelog.d/20230718_175341_jsick_ruff.md b/changelog.d/20230718_175341_jsick_ruff.md new file mode 100644 index 0000000..21aa7e9 --- /dev/null +++ b/changelog.d/20230718_175341_jsick_ruff.md @@ -0,0 +1,4 @@ +### Other changes + +- Use ruff for linting the codebase, replacing flake8 and isort. +- Improve the codebase following ruff's recommendations. diff --git a/docs/_rst_epilog.rst b/docs/_rst_epilog.rst index 828e6bf..e8735b0 100644 --- a/docs/_rst_epilog.rst +++ b/docs/_rst_epilog.rst @@ -10,3 +10,4 @@ .. _Schema Evolution and Compatibility: https://docs.confluent.io/current/schema-registry/avro.html .. _Strimzi: https://strimzi.io .. _tox: https://tox.readthedocs.io/en/latest/ +.. _scriv: https://scriv.readthedocs.io/en/latest/ diff --git a/docs/dev/development.rst b/docs/dev/development.rst index 6276aad..9413768 100644 --- a/docs/dev/development.rst +++ b/docs/dev/development.rst @@ -43,8 +43,8 @@ Pre-commit hooks The pre-commit hooks, which are automatically installed by running the :command:`make init` command on :ref:`set up `, ensure that files are valid and properly formatted. Some pre-commit hooks automatically reformat code: -``isort`` - Automatically sorts imports in Python modules. +``ruff`` + Automatically fixes common issues in code and sorts imports. ``black`` Automatically formats Python code. @@ -99,29 +99,17 @@ Updating the change log ======================= Each pull request should update the change log (:file:`CHANGELOG.md`). -Add a description of new features and fixes as list items under a section at the top of the change log called "Unreleased:" +The change log is maintained with scriv_. -.. code-block:: md +To create a new change log fragment, run: - ## Unreleased - - - Description of the feature or fix. - -If the next version is known (because Kafkit's main branch is being prepared for a new major or minor version), the section may contain that version information: - -.. code-block:: md - - ## X.Y.0 (unreleased) - - - Description of the feature or fix. - -If the exact version and release date is known (:doc:`because a release is being prepared `), the section header is formatted as: - -.. code-block:: rst +.. code-block:: sh - ## X.Y.0 (YYYY-MM-DD) + scriv create - - Description of the feature or fix. +This creates a new file in the :file:`changelog.d` directory. +Edit this file to describe the changes in the pull request. +If sections don't apply to the change you can delete them. .. 
_style-guide: @@ -131,7 +119,7 @@ Style guide Code ---- -- The code style follows :pep:`8`, though in practice lean on Black and isort to format the code for you. +- The code style follows :pep:`8`, though in practice lean on Black and ruff to format the code for you. - Use :pep:`484` type annotations. The ``tox -e typing`` test environment, which runs mypy_, ensures that the project's types are consistent. diff --git a/docs/dev/release.rst b/docs/dev/release.rst index abf4b73..4cdff44 100644 --- a/docs/dev/release.rst +++ b/docs/dev/release.rst @@ -26,11 +26,16 @@ Release tags are semantic version identifiers following the :pep:`440` specifica 1. Change log and documentation ------------------------------- -Each PR should include updates to the change log. +Each PR should include updates to the change log as scriv_ fragments (see :ref:`dev-change-log`). +When a release is being made, collect these fragments into the change log by running: + +.. code-block:: sh + + scriv collect --version "X.Y.Z" + If the change log or documentation needs additional updates, now is the time to make those changes through the regular branch-and-PR development method against the ``main`` branch. -In particular, replace the "Unreleased" section headline with the semantic version and date. -See :ref:`dev-change-log` in the *Developer guide* for details. +Each PR should have already created scriv_ change log fragments. 2. Tag the release ------------------ diff --git a/pyproject.toml b/pyproject.toml index aa80ac8..0e629fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,31 +97,6 @@ exclude = ''' # Use single-quoted strings so TOML treats the string like a Python r-string # Multi-line strings are implicitly treated by black as regular expressions -[tool.pydocstyle] -# Reference: http://www.pydocstyle.org/en/stable/error_codes.html -convention = "numpy" -add_select = [ - "D212", # Multi-line docstring summary should start at the first line -] -add-ignore = [ - "D105", # Missing docstring in magic method - "D102", # Missing docstring in public method (needed for docstring inheritance) - "D100", # Missing docstring in public module - # Below are required to allow multi-line summaries. - "D200", # One-line docstring should fit on one line with quotes - "D205", # 1 blank line required between summary line and description - "D400", # First line should end with a period - # Properties shouldn't be written in imperative mode. This will be fixed - # post 6.1.1, see https://github.com/PyCQA/pydocstyle/pull/546 - "D401", -] - -[tool.isort] -profile = "black" -line_length = 79 -known_first_party = "kafkit" -skip = ["docs/conf.py"] - [tool.pytest] [tool.pytest.ini_options] @@ -140,6 +115,93 @@ warn_redundant_casts = true warn_unreachable = true warn_unused_ignores = true +# The rule used with Ruff configuration is to disable every lint that has +# legitimate exceptions that are not dodgy code, rather than cluttering code +# with noqa markers. This is therefore a relatively relaxed configuration that +# errs on the side of disabling legitimate lints.
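+# Exceptions that apply only to the test suite are collected in the +# [tool.ruff.per-file-ignores] table for "tests/**" below, rather than +# marked inline with noqa.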
+# +# Reference for settings: https://beta.ruff.rs/docs/settings/ +# Reference for rules: https://beta.ruff.rs/docs/rules/ +[tool.ruff] +exclude = [ + "docs/**", +] +line-length = 79 +ignore = [ + "ANN101", # self should not have a type annotation + "ANN102", # cls should not have a type annotation + "ANN401", # sometimes Any is the right type + "ARG001", # unused function arguments are often legitimate + "ARG002", # unused method arguments are often legitimate + "ARG005", # unused lambda arguments are often legitimate + "BLE001", # we want to catch and report Exception in background tasks + "C414", # nested sorted is how you sort by multiple keys with reverse + "COM812", # omitting trailing commas allows black autoreformatting + "D102", # sometimes we use docstring inheritance + "D104", # don't see the point of documenting every package + "D105", # our style doesn't require docstrings for magic methods + "D106", # Pydantic uses a nested Config class that doesn't warrant docs + "D205", # Allow multi-line summary sentences + "EM101", # justification (duplicate string in traceback) is silly + "EM102", # justification (duplicate string in traceback) is silly + "FBT003", # positional booleans are normal for Pydantic field defaults + "G004", # forbidding logging f-strings is appealing, but not our style + "RET505", # disagree that omitting else always makes code more readable + "PLR0913", # factory pattern uses constructors with many arguments + "PLR2004", # too aggressive about magic values + "S105", # good idea but too many false positives on non-passwords + "S106", # good idea but too many false positives on non-passwords + "SIM102", # sometimes the formatting of nested if statements is clearer + "SIM116", # allow if-else-if-else chains + "SIM117", # sometimes nested with contexts are clearer + "TCH001", # we decided to not maintain separate TYPE_CHECKING blocks + "TCH002", # we decided to not maintain separate TYPE_CHECKING blocks + "TCH003", # we decided to not maintain separate TYPE_CHECKING blocks + "TID252", # if we're going to use relative imports, use them always + "TRY003", # good general advice but lint is way too aggressive +] +select = ["ALL"] +target-version = "py310" + +[tool.ruff.per-file-ignores] +"tests/**" = [ + "D103", # tests don't need docstrings + "PLR0915", # tests are allowed to be long, sometimes that's convenient + "PT012", # way too aggressive about limiting pytest.raises blocks + "S101", # tests should use assert + "SLF001", # tests are allowed to access private members + "T201", # Print is ok in tests +] + +[tool.ruff.isort] +known-first-party = ["kafkit", "tests"] +split-on-trailing-comma = false + +# These are too useful as attributes or methods to allow the conflict with the +# built-in to rule out their use. +[tool.ruff.flake8-builtins] +builtins-ignorelist = [ + "all", + "any", + "help", + "id", + "list", + "type", +] + +[tool.ruff.flake8-pytest-style] +fixture-parentheses = false +mark-parentheses = false + +[tool.ruff.pep8-naming] +classmethod-decorators = [ + "pydantic.root_validator", + "pydantic.validator", +] + +[tool.ruff.pydocstyle] +convention = "numpy" + [tool.scriv] categories = [ "Backwards-incompatible changes", diff --git a/src/kafkit/__init__.py b/src/kafkit/__init__.py index e80343c..2542d50 100644 --- a/src/kafkit/__init__.py +++ b/src/kafkit/__init__.py @@ -1,10 +1,10 @@ -"""Kafkit helps you write Kafka producers and consumers in Python with asyncio. +"""Kafkit helps you write Kafka producers and consumers in Python +with asyncio.
""" __all__ = ["__version__", "version_info"] from importlib.metadata import PackageNotFoundError, version -from typing import List __version__: str """The version string of Kafkit (PEP 440 / SemVer compatible).""" @@ -15,7 +15,7 @@ # package is not installed __version__ = "0.0.0" -version_info: List[str] = __version__.split(".") +version_info: list[str] = __version__.split(".") """The decomposed version, split across "``.``." Use this for version comparison. diff --git a/src/kafkit/fastapi/dependencies/pydanticschemamanager.py b/src/kafkit/fastapi/dependencies/pydanticschemamanager.py index 2c0ec1e..c50cf13 100644 --- a/src/kafkit/fastapi/dependencies/pydanticschemamanager.py +++ b/src/kafkit/fastapi/dependencies/pydanticschemamanager.py @@ -3,7 +3,6 @@ """ from collections.abc import Iterable -from typing import Type from dataclasses_avroschema.avrodantic import AvroBaseModel from httpx import AsyncClient @@ -30,7 +29,7 @@ async def initialize( *, http_client: AsyncClient, registry_url: str, - models: Iterable[Type[AvroBaseModel]], + models: Iterable[type[AvroBaseModel]], suffix: str = "", compatibility: str = "FORWARD", ) -> None: diff --git a/src/kafkit/httputils.py b/src/kafkit/httputils.py index 3f13a06..be94baa 100644 --- a/src/kafkit/httputils.py +++ b/src/kafkit/httputils.py @@ -6,7 +6,7 @@ import cgi import urllib.parse -from typing import Mapping, Optional, Tuple +from collections.abc import Mapping import uritemplate @@ -39,8 +39,8 @@ def format_url(*, host: str, url: str, url_vars: Mapping[str, str]) -> str: def parse_content_type( - content_type: Optional[str], -) -> Tuple[Optional[str], str]: + content_type: str | None, +) -> tuple[str | None, str]: """Tease out the content-type and character encoding. A default character encoding of UTF-8 is used, so the content-type diff --git a/src/kafkit/registry/aiohttp.py b/src/kafkit/registry/aiohttp.py index c8bea79..eec0dc8 100644 --- a/src/kafkit/registry/aiohttp.py +++ b/src/kafkit/registry/aiohttp.py @@ -6,7 +6,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Mapping, Tuple +from collections.abc import Mapping +from typing import TYPE_CHECKING from kafkit.registry import sansio @@ -33,7 +34,7 @@ def __init__(self, *, session: ClientSession, url: str) -> None: async def _request( self, method: str, url: str, headers: Mapping[str, str], body: bytes - ) -> Tuple[int, Mapping[str, str], bytes]: + ) -> tuple[int, Mapping[str, str], bytes]: async with self._session.request( method, url, headers=headers, data=body ) as response: diff --git a/src/kafkit/registry/errors.py b/src/kafkit/registry/errors.py index a538477..eadd1ec 100644 --- a/src/kafkit/registry/errors.py +++ b/src/kafkit/registry/errors.py @@ -9,7 +9,7 @@ "UnmanagedSchemaError", ] -from typing import Any, Optional +from typing import Any class RegistryError(Exception): @@ -37,8 +37,8 @@ def __init__( self, status_code: int, *args: Any, - error_code: Optional[int] = None, - message: Optional[str] = None, + error_code: int | None = None, + message: str | None = None, ) -> None: self.status_code = status_code self.error_code = error_code diff --git a/src/kafkit/registry/httpx.py b/src/kafkit/registry/httpx.py index bc1d2c9..1e198f0 100644 --- a/src/kafkit/registry/httpx.py +++ b/src/kafkit/registry/httpx.py @@ -6,7 +6,7 @@ from __future__ import annotations -from typing import Mapping, Tuple +from collections.abc import Mapping from httpx import AsyncClient @@ -30,7 +30,7 @@ def __init__(self, *, http_client: AsyncClient, url: str) -> None: async def 
_request( self, method: str, url: str, headers: Mapping[str, str], body: bytes - ) -> Tuple[int, Mapping[str, str], bytes]: + ) -> tuple[int, Mapping[str, str], bytes]: response = await self._client.request( method, url, headers=headers, content=body ) diff --git a/src/kafkit/registry/manager/_pydantic.py b/src/kafkit/registry/manager/_pydantic.py index 0afd188..752660c 100644 --- a/src/kafkit/registry/manager/_pydantic.py +++ b/src/kafkit/registry/manager/_pydantic.py @@ -5,8 +5,9 @@ from __future__ import annotations import logging +from collections.abc import Iterable from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Iterable, Optional, Type +from typing import TYPE_CHECKING, Any from dataclasses_avroschema.avrodantic import AvroBaseModel @@ -28,7 +29,7 @@ class CachedSchema: schema: dict[str, Any] """The Avro schema derived from the model.""" - model: Type[AvroBaseModel] + model: type[AvroBaseModel] """The Pydantic model.""" @@ -66,8 +67,8 @@ def __init__(self, *, registry: RegistryApi, suffix: str = "") -> None: async def register_models( self, - models: Iterable[Type[AvroBaseModel]], - compatibility: Optional[str] = None, + models: Iterable[type[AvroBaseModel]], + compatibility: str | None = None, ) -> None: """Register the models with the registry. @@ -80,7 +81,7 @@ async def register_models( await self.register_model(model, compatibility=compatibility) async def register_model( - self, model: Type[AvroBaseModel], compatibility: Optional[str] = None + self, model: type[AvroBaseModel], compatibility: str | None = None ) -> None: """Register the model with the registry. @@ -156,7 +157,7 @@ async def deserialize(self, data: bytes) -> AvroBaseModel: return cached_model.model.parse_obj(message_info.message) def _cache_model( - self, model: AvroBaseModel | Type[AvroBaseModel] + self, model: AvroBaseModel | type[AvroBaseModel] ) -> CachedSchema: schema_fqn = self._get_model_fqn(model) avro_schema = model.avro_schema_to_python() @@ -171,18 +172,18 @@ def _cache_model( return self._models[schema_fqn] def _get_model_fqn( - self, model: AvroBaseModel | Type[AvroBaseModel] + self, model: AvroBaseModel | type[AvroBaseModel] ) -> str: # Mypy can't detect the Meta class on the model, so we have to ignore # those lines. try: - name = model.Meta.schema_name # type: ignore + name = model.Meta.schema_name # type: ignore [union-attr] except AttributeError: name = model.__class__.__name__ try: - namespace = model.Meta.namespace # type: ignore + namespace = model.Meta.namespace # type: ignore [union-attr] except AttributeError: namespace = None diff --git a/src/kafkit/registry/manager/_recordname.py b/src/kafkit/registry/manager/_recordname.py index 3013bb1..4ffb394 100644 --- a/src/kafkit/registry/manager/_recordname.py +++ b/src/kafkit/registry/manager/_recordname.py @@ -7,7 +7,7 @@ import json import logging from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Any from kafkit.registry.serializer import PolySerializer @@ -83,7 +83,7 @@ def __init__( self._logger = logging.getLogger(__name__) self._serializer = PolySerializer(registry=self._registry) - self.schemas: Dict[str, Any] = {} + self.schemas: dict[str, Any] = {} self._load_schemas() @@ -102,7 +102,7 @@ def _load_schemas(self) -> None: self.schemas[fqn] = schema async def register_schemas( - self, *, compatibility: Optional[str] = None + self, *, compatibility: str | None = None ) -> None: """Register all local schemas with the Confluent Schema Registry. 
@@ -176,8 +176,11 @@ async def serialize(self, *, data: Any, name: str) -> bytes: schema = self.schemas[name] - encoded_message = await self._serializer.serialize( - data, schema=schema, subject=name - ) - - return encoded_message + try: + return await self._serializer.serialize( + data, schema=schema, subject=name + ) + except ValueError as e: + raise ValueError( + f"Cannot serialize data with schema {name}" + ) from e diff --git a/src/kafkit/registry/sansio.py b/src/kafkit/registry/sansio.py index 979465c..2855899 100644 --- a/src/kafkit/registry/sansio.py +++ b/src/kafkit/registry/sansio.py @@ -9,11 +9,13 @@ from __future__ import annotations import abc +import contextlib import copy import json import logging +from collections.abc import Mapping from enum import Enum -from typing import Any, Dict, Mapping, Optional, Tuple, Union, overload +from typing import Any, overload import fastavro @@ -37,7 +39,7 @@ ] -def make_headers() -> Dict[str, str]: +def make_headers() -> dict[str, str]: """Make HTTP headers for the Confluent Schema Registry. Returns @@ -46,8 +48,7 @@ def make_headers() -> Dict[str, str]: A dictionary of HTTP headers for a Confluent Schema Registry request. All keys are normalized to lowercase for consistency. """ - headers = {"accept": "application/vnd.schemaregistry.v1+json"} - return headers + return {"accept": "application/vnd.schemaregistry.v1+json"} def decipher_response( @@ -72,17 +73,17 @@ def decipher_response( raise RegistryBrokenError( status_code=status_code, error_code=error_code, message=message ) - elif status_code >= 400: + if status_code >= 400: raise RegistryBadRequestError( status_code=status_code, error_code=error_code, message=message ) - elif status_code >= 300: + if status_code >= 300: raise RegistryRedirectionError(status_code=status_code) - else: - raise RegistryHttpError(status_code=status_code) + + raise RegistryHttpError(status_code=status_code) -def decode_body(content_type: Optional[str], body: bytes) -> Any: +def decode_body(content_type: str | None, body: bytes) -> Any: """Decode an HTTP body based on the specified content type. Parameters @@ -143,7 +144,7 @@ def subject_cache(self) -> SubjectCache: @abc.abstractmethod async def _request( self, method: str, url: str, headers: Mapping[str, str], body: bytes - ) -> Tuple[int, Mapping[str, str], bytes]: + ) -> tuple[int, Mapping[str, str], bytes]: """Make an HTTP request. Parameters @@ -180,11 +181,10 @@ async def _make_request( response = await self._request( method, expanded_url, request_headers, body ) - response_data = decipher_response(*response) - return response_data + return decipher_response(*response) async def get( - self, url: str, url_vars: Optional[Mapping[str, str]] = None + self, url: str, url_vars: Mapping[str, str] | None = None ) -> Any: """Send an HTTP GET request. 
@@ -217,13 +217,12 @@ async def get( """ if url_vars is None: url_vars = {} - data = await self._make_request("GET", url, url_vars, b"") - return data + return await self._make_request("GET", url, url_vars, b"") async def post( self, url: str, - url_vars: Optional[Mapping[str, str]] = None, + url_vars: Mapping[str, str] | None = None, *, data: Any, ) -> Any: @@ -260,13 +259,12 @@ async def post( """ if url_vars is None: url_vars = {} - data = await self._make_request("POST", url, url_vars, data) - return data + return await self._make_request("POST", url, url_vars, data) async def patch( self, url: str, - url_vars: Optional[Mapping[Any, Any]] = None, + url_vars: Mapping[Any, Any] | None = None, *, data: Any, ) -> Any: @@ -303,13 +301,12 @@ async def patch( """ if url_vars is None: url_vars = {} - data = await self._make_request("PATCH", url, url_vars, data) - return data + return await self._make_request("PATCH", url, url_vars, data) async def put( self, url: str, - url_vars: Optional[Mapping[str, str]] = None, + url_vars: Mapping[str, str] | None = None, data: Any = b"", ) -> Any: """Send an HTTP PUT request. @@ -345,13 +342,12 @@ async def put( """ if url_vars is None: url_vars = {} - data = await self._make_request("PUT", url, url_vars, data) - return data + return await self._make_request("PUT", url, url_vars, data) async def delete( self, url: str, - url_vars: Optional[Mapping[str, str]] = None, + url_vars: Mapping[str, str] | None = None, *, data: Any = b"", ) -> Any: @@ -386,8 +382,7 @@ async def delete( """ if url_vars is None: url_vars = {} - data = await self._make_request("DELETE", url, url_vars, data) - return data + return await self._make_request("DELETE", url, url_vars, data) @staticmethod def _prep_schema(schema: Mapping[str, Any]) -> str: @@ -405,8 +400,8 @@ def _prep_schema(schema: Mapping[str, Any]) -> str: async def register_schema( self, schema: Mapping[str, Any], - subject: Optional[str] = None, - compatibility: Optional[str] = None, + subject: str | None = None, + compatibility: str | None = None, ) -> int: """Register a schema or get the ID of an existing schema. @@ -443,19 +438,18 @@ async def register_schema( # look in cache first try: - schema_id = self.schema_cache[schema] - return schema_id + return self.schema_cache[schema] except KeyError: pass if subject is None: try: subject = schema["name"] - except (KeyError, TypeError): + except (KeyError, TypeError) as e: raise RuntimeError( "Cannot get a subject name from a 'name' " f"key in the schema: {schema!r}" - ) + ) from e result = await self.post( "/subjects{/subject}/versions", @@ -477,11 +471,11 @@ async def set_subject_compatibility( # Validate compatibility setting try: CompatibilityType[compatibility] - except KeyError: + except KeyError as e: raise ValueError( f"Compatibility setting {compatibility!r} is not in the " f"allowed set: {[v.value for v in CompatibilityType]}" - ) + ) from e try: subject_config = await self.get( @@ -518,7 +512,7 @@ async def set_subject_compatibility( compatibility, ) - async def get_schema_by_id(self, schema_id: int) -> Dict[str, Any]: + async def get_schema_by_id(self, schema_id: int) -> dict[str, Any]: """Get a schema from the registry given its ID. Wraps ``GET /schemas/ids/{int: id}``. 
@@ -546,8 +540,7 @@ async def get_schema_by_id(self, schema_id: int) -> Dict[str, Any]: """ # Look in the cache first try: - schema = self.schema_cache[schema_id] - return schema + return self.schema_cache[schema_id] except KeyError: pass @@ -562,8 +555,8 @@ async def get_schema_by_id(self, schema_id: int) -> Dict[str, Any]: return schema async def get_schema_by_subject( - self, subject: str, version: Union[str, int] = "latest" - ) -> Dict[str, Any]: + self, subject: str, version: str | int = "latest" + ) -> dict[str, Any]: """Get a schema for a subject in the registry. Wraps ``GET /subjects/(string: subject)/versions/(versionId: version)`` @@ -605,12 +598,13 @@ async def get_schema_by_subject( calls this method, and you want to make use of caching, replace ``"latest"`` versions with integer versions once they're known. """ - try: - # The SubjectCache.get method is designed to have the same return - # type as this method. - return self.subject_cache.get(subject, version) - except ValueError: - pass + if isinstance(version, int): + try: + # The SubjectCache.get method is designed to have the same + # return type as this method. + return self.subject_cache.get(subject, version) + except ValueError: + pass result = await self.get( "/subjects{/subject}/versions{/version}", @@ -619,16 +613,14 @@ async def get_schema_by_subject( schema = fastavro.parse_schema(json.loads(result["schema"])) - try: + with contextlib.suppress(TypeError): + # Can't cache versions like "latest" self.subject_cache.insert( result["subject"], result["version"], schema_id=result["id"], schema=schema, ) - except TypeError: - # Can't cache versions like "latest" - pass return { "id": result["id"], @@ -643,20 +635,16 @@ class MockRegistryApi(RegistryApi): network operations and provides attributes for introspection. """ - DEFAULT_HEADERS = { - "content-type": "application/vnd.schemaregistry.v1+json" - } - def __init__( self, url: str = "http://registry:8081", status_code: int = 200, - headers: Optional[Mapping[str, str]] = None, + headers: Mapping[str, str] | None = None, body: Any = b"", ) -> None: super().__init__(url=url) self.response_code = status_code - self.response_headers = headers if headers else self.DEFAULT_HEADERS + self.response_headers = headers if headers else self._default_headers self.response_body = body async def _request( @@ -669,6 +657,10 @@ async def _request( response_headers = copy.deepcopy(self.response_headers) return self.response_code, response_headers, self.response_body + @property + def _default_headers(self) -> dict[str, str]: + return {"content-type": "application/vnd.schemaregistry.v1+json"} + class SchemaCache: """A cache of schemas that maintains a mapping of schemas and their IDs @@ -681,8 +673,8 @@ class SchemaCache: """ def __init__(self) -> None: - self._id_to_schema: Dict[int, str] = {} - self._schema_to_id: Dict[str, int] = {} + self._id_to_schema: dict[int, str] = {} + self._schema_to_id: dict[str, int] = {} def insert(self, schema: Mapping[str, Any], schema_id: int) -> None: """Insert a schema into the cache. 
@@ -696,22 +688,22 @@ def insert(self, schema: Mapping[str, Any], schema_id: int) -> None: """ # ensure the cached schemas are always parsed, and then serialize # so it's hashable - serialized_schema = SchemaCache._serialize_schema(schema) + serialized_schema = self._serialize_schema(schema) self._id_to_schema[schema_id] = serialized_schema self._schema_to_id[serialized_schema] = schema_id @overload - def __getitem__(self, key: int) -> Dict[str, Any]: + def __getitem__(self, key: int) -> dict[str, Any]: ... - @overload # noqa: F811 remove for pyflakes 2.2.x - def __getitem__(self, key: Mapping[str, Any]) -> int: # noqa: F811 + @overload + def __getitem__(self, key: Mapping[str, Any]) -> int: ... - def __getitem__( # noqa: F811 remove for pyflakes 2.2.x - self, key: Union[Mapping[str, Any], int] - ) -> Union[Dict[str, Any], int]: + def __getitem__( + self, key: Mapping[str, Any] | int + ) -> dict[str, Any] | int: if isinstance(key, int): return json.loads(self._id_to_schema[key]) else: @@ -719,16 +711,16 @@ def __getitem__( # noqa: F811 remove for pyflakes 2.2.x # Always ensure the schema is parsed schema = copy.deepcopy(key) try: - serialized_schema = SchemaCache._serialize_schema(schema) - except Exception: + serialized_schema = self._serialize_schema(schema) + except Exception as e: # If the schema couldn't be parsed, its not going to be a # valid key anyhow. raise KeyError( f"Key or schema not in the SchemaCache: {key!r}" - ) + ) from e return self._schema_to_id[serialized_schema] - def __contains__(self, key: Union[int, Mapping[str, Any]]) -> bool: + def __contains__(self, key: int | Mapping[str, Any]) -> bool: try: self[key] except KeyError: @@ -764,7 +756,7 @@ class SubjectCache: def __init__(self, schema_cache: SchemaCache) -> None: self.schema_cache = schema_cache - self._subject_to_id: Dict[Tuple[str, int], int] = {} + self._subject_to_id: dict[tuple[str, int], int] = {} def get_id(self, subject: str, version: int) -> int: """Get the schema ID of a subject version. @@ -794,9 +786,12 @@ def get_id(self, subject: str, version: int) -> int: try: return self._subject_to_id[(subject, version)] except KeyError as e: - raise ValueError from e + raise ValueError( + f"Schema with subject {subject!r} version {version!r} " + "not cached." + ) from e - def get_schema(self, subject: str, version: int) -> Dict[str, Any]: + def get_schema(self, subject: str, version: int) -> dict[str, Any]: """Get the schema of a subject version. Parameters @@ -823,12 +818,14 @@ def get_schema(self, subject: str, version: int) -> Dict[str, Any]: get """ try: - schema = self.schema_cache[self.get_id(subject, version)] - return schema + return self.schema_cache[self.get_id(subject, version)] except KeyError as e: - raise ValueError from e + raise ValueError( + f"Schema with subject {subject!r} version {version!r} " + "not cached." + ) from e - def get(self, subject: str, version: Union[int, str]) -> Dict[str, Any]: + def get(self, subject: str, version: int | str) -> dict[str, Any]: """Get the full set of schema and ID information for a subject version. Parameters @@ -837,7 +834,7 @@ def get(self, subject: str, version: Union[int, str]) -> Dict[str, Any]: The name of the subject. version : `int` The version number of the schema in the subject. If version is - given as a string (``"latest"``), a `ValueError` is raised. + given as a string (``"latest"``), a `TypeError` is raised. 
Returns ------- @@ -859,6 +856,8 @@ def get(self, subject: str, version: Union[int, str]) -> Dict[str, Any]: ------ ValueError Raised if the schema does not exist in the cache. + TypeError + Raised if the version is a string, like "latest". See Also -------- @@ -866,12 +865,16 @@ def get(self, subject: str, version: Union[int, str]) -> Dict[str, Any]: get_schema """ if not isinstance(version, int): - raise ValueError("version must be an int, got {}".format(version)) + msg = f"version must be an int, got {version}" + raise TypeError(msg) try: schema_id = self.get_id(subject, version) schema = self.schema_cache[schema_id] except KeyError as e: - raise ValueError from e + raise ValueError( + f"Schema with subject {subject!r} version {version!r} " + "not cached." + ) from e # Important: this return type maches RegistryApi.get_schema_by_subject # If this is changed, make sure get_schema_by_subject is also changed. @@ -886,8 +889,8 @@ def insert( self, subject: str, version: int, - schema_id: Optional[int] = None, - schema: Optional[Mapping[str, Any]] = None, + schema_id: int | None = None, + schema: Mapping[str, Any] | None = None, ) -> None: """Insert a subject version into the cache. @@ -955,11 +958,11 @@ def insert( "Provide either a schema_id or schema argument (or both)." ) - def __contains__(self, key: Tuple[str, int]) -> bool: + def __contains__(self, key: tuple[str, int]) -> bool: return key in self._subject_to_id -class CompatibilityType(str, Enum): +class CompatibilityType(Enum): """Compatibility settings available for the Confluent Schema Registry, as an Enum. diff --git a/src/kafkit/registry/serializer.py b/src/kafkit/registry/serializer.py index 0547e03..4c19760 100644 --- a/src/kafkit/registry/serializer.py +++ b/src/kafkit/registry/serializer.py @@ -7,7 +7,7 @@ import struct from dataclasses import dataclass from io import BytesIO -from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple +from typing import TYPE_CHECKING, Any import fastavro @@ -71,7 +71,7 @@ class Serializer: https://aiokafka.readthedocs.io/en/stable/examples/serialize_and_compress.html """ - def __init__(self, *, schema: Dict[str, Any], schema_id: int) -> None: + def __init__(self, *, schema: dict[str, Any], schema_id: int) -> None: self.schema = fastavro.parse_schema(schema) self.id = schema_id @@ -80,8 +80,8 @@ async def register( cls, *, registry: RegistryApi, - schema: Dict[str, Any], - subject: Optional[str] = None, + schema: dict[str, Any], + subject: str | None = None, ) -> Serializer: """Create a serializer ensuring that the schema is registered with the schema registry. @@ -147,9 +147,9 @@ def __init__(self, *, registry: RegistryApi) -> None: async def serialize( self, data: Any, - schema: Optional[Dict[str, Any]] = None, - schema_id: Optional[int] = None, - subject: Optional[str] = None, + schema: dict[str, Any] | None = None, + schema_id: int | None = None, + subject: str | None = None, ) -> bytes: """Serialize data given a schema. 
@@ -194,7 +194,7 @@ async def serialize( def _make_message( - *, schema_id: int, schema: Dict[str, Any], data: Any + *, schema_id: int, schema: dict[str, Any], data: Any ) -> bytes: """Make a message in the Confluent Wire Format.""" message_fh = BytesIO() @@ -348,7 +348,7 @@ def pack_wire_format_prefix(schema_id: int) -> bytes: return struct.pack(">bI", 0, schema_id) -def unpack_wire_format_data(data: bytes) -> Tuple[int, bytes]: +def unpack_wire_format_data(data: bytes) -> tuple[int, bytes]: """Unpackage the bytes of a Confluent Wire Format message to get the schema ID and message body. @@ -373,7 +373,7 @@ def unpack_wire_format_data(data: bytes) -> Tuple[int, bytes]: """ if len(data) < 5: raise RuntimeError( - f"Data is too short, length is {len(data)} " "bytes. Must be >= 5." + f"Data is too short, length is {len(data)} bytes. Must be >= 5." ) prefix = data[:5] diff --git a/src/kafkit/registry/utils.py b/src/kafkit/registry/utils.py index b1f20e2..09b42a2 100644 --- a/src/kafkit/registry/utils.py +++ b/src/kafkit/registry/utils.py @@ -2,7 +2,8 @@ from __future__ import annotations -from typing import Any, Mapping +from collections.abc import Mapping +from typing import Any __all__ = ["get_avro_fqn"] @@ -34,5 +35,4 @@ def get_avro_fqn(schema: Mapping[str, Any]) -> str: fqn = ".".join((schema["namespace"], schema["name"])) else: fqn = schema["name"] - assert isinstance(fqn, str) return fqn diff --git a/src/kafkit/settings.py b/src/kafkit/settings.py index a7b89af..df19751 100644 --- a/src/kafkit/settings.py +++ b/src/kafkit/settings.py @@ -17,7 +17,7 @@ ] -class KafkaSecurityProtocol(str, Enum): +class KafkaSecurityProtocol(Enum): """Kafka security protocols understood by aiokafka.""" PLAINTEXT = "PLAINTEXT" @@ -27,7 +27,7 @@ class KafkaSecurityProtocol(str, Enum): """TLS-encrypted connection.""" -class KafkaSaslMechanism(str, Enum): +class KafkaSaslMechanism(Enum): """Kafka SASL mechanisms understood by aiokafka.""" PLAIN = "PLAIN" @@ -171,11 +171,6 @@ def ssl_context(self) -> SSLContext | None: ): return None - # For type checking - assert self.client_cert_path is not None - assert self.cluster_ca_path is not None - assert self.client_key_path is not None - client_cert_path = Path(self.client_cert_path) if self.client_ca_path is not None: @@ -188,10 +183,7 @@ def ssl_context(self) -> SSLContext | None: ) client_ca = Path(self.client_ca_path).read_text() client_cert = Path(self.client_cert_path).read_text() - if client_ca.endswith("\n"): - sep = "" - else: - sep = "\n" + sep = "" if client_ca.endswith("\n") else "\n" new_client_cert = sep.join([client_cert, client_ca]) new_client_cert_path = Path(self.cert_temp_dir) / "client.crt" new_client_cert_path.write_text(new_client_cert) diff --git a/tests/httputils_test.py b/tests/httputils_test.py index 66c8cca..24c0214 100644 --- a/tests/httputils_test.py +++ b/tests/httputils_test.py @@ -1,6 +1,5 @@ """Tests for the kafkit.utils module.""" -from typing import Dict import pytest @@ -8,7 +7,7 @@ @pytest.mark.parametrize( - "host,url,url_vars,expected", + ("host", "url", "url_vars", "expected"), [ ( "http://confluent-kafka-cp-schema-registry:8081", @@ -37,7 +36,7 @@ ], ) def test_format_url( - host: str, url: str, url_vars: Dict[str, str], expected: str + host: str, url: str, url_vars: dict[str, str], expected: str ) -> None: """Test `kafkit.httputils.format_url`.""" assert expected == format_url(host=host, url=url, url_vars=url_vars) diff --git a/tests/pydantic_schema_manager_test.py b/tests/pydantic_schema_manager_test.py index 
2f919f9..56f9e6a 100644 --- a/tests/pydantic_schema_manager_test.py +++ b/tests/pydantic_schema_manager_test.py @@ -5,7 +5,6 @@ import os from datetime import datetime, timezone from enum import Enum -from typing import Optional import pytest from dataclasses_avroschema.avrodantic import AvroBaseModel @@ -22,14 +21,14 @@ def current_datetime() -> datetime: return datetime.now(tz=timezone.utc) -class SlackMessageType(str, Enum): +class SlackMessageType(Enum): """The type of Slack message.""" app_mention = "app_mention" message = "message" -class SlackChannelType(str, Enum): +class SlackChannelType(Enum): """The type of Slack channel.""" channel = "channel" # public channel @@ -54,7 +53,7 @@ class SquarebotMessage(AvroBaseModel): description="The type of channel (public, direct im, etc..)" ) - user: Optional[str] = Field( + user: str | None = Field( description="The ID of the user that sent the message (eg U061F7AUR)." ) diff --git a/tests/registry_manager_test.py b/tests/registry_manager_test.py index 5be882b..154ad4d 100644 --- a/tests/registry_manager_test.py +++ b/tests/registry_manager_test.py @@ -42,5 +42,8 @@ async def test_recordnameschemamanager() -> None: assert isinstance(data_b, bytes) # Sanity check that you can't serialize with the wrong schema! - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match=("Cannot serialize data with schema kafkit.a"), + ): await manager.serialize(data=topic_b_message, name="kafkit.a") diff --git a/tests/registry_sansio_test.py b/tests/registry_sansio_test.py index 3e3557b..c63447c 100644 --- a/tests/registry_sansio_test.py +++ b/tests/registry_sansio_test.py @@ -344,8 +344,8 @@ def test_schema_cache() -> None: with pytest.raises(KeyError): cache[0] with pytest.raises(KeyError): - schemaX = {"type": "unknown"} - cache[schemaX] + schema_x = {"type": "unknown"} + cache[schema_x] def test_subject_cache() -> None: @@ -394,11 +394,20 @@ def test_subject_cache() -> None: assert cache.get_schema("schema2", 32)["name"] == "test-schemas.schema2" # Test inserting a subject that does not have a pre-cached schema - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match=r"^Provide either a", + ): cache.insert("schema3", 13) - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match="^Trying to cache the schema ID for subject 'schema3'", + ): cache.insert("schema3", 13, schema_id=3) - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match="^Trying to cache the schema ID for subject 'schema3'", + ): cache.insert("schema3", 13, schema=schema3) cache.insert("schema3", 13, schema=schema3, schema_id=3) assert ("schema3", 13) in cache @@ -406,13 +415,24 @@ def test_subject_cache() -> None: assert cache.get_schema("schema3", 13)["name"] == "test-schemas.schema3" # Test getting a non-existent subject or version - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match="Schema with subject 'schema3' version 25 not cached.", + ): cache.get_id("schema3", 25) - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match="Schema with subject 'schema18' version 25 not cached.", + ): cache.get_schema("schema18", 25) - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match="Schema with subject 'schema18' version 15 not cached.", + ): cache.get("schema18", 15) # Test caching 'latest' with pytest.raises(TypeError): - cache.insert("mysubject", "latest", schema_id=42) # type: ignore + cache.insert( + "mysubject", "latest", schema_id=42 # type: 
ignore[arg-type] + ) diff --git a/tox.ini b/tox.ini index 19aaa3d..52e7591 100644 --- a/tox.ini +++ b/tox.ini @@ -13,15 +13,15 @@ extras = pydantic aiokafka allowlist_externals = - docker-compose + docker setenv = KAFKA_BROKER_URL=localhost:9092 SCHEMA_REGISTRY_URL=http://localhost:8081 commands = - docker-compose up -d + docker compose up -d holdup -t 60 -T 5 -i 1 -n http://localhost:8081/subjects coverage run -m pytest {posargs} - docker-compose down + docker compose down [testenv:coverage-report] description = Compile coverage from each test run.