Skip to content

Use msgspec instead of pydantic #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ repos:
args: [--markdown-linebreak-ext=md]

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.9
rev: v0.9.9
hooks:
- id: ruff
- id: ruff-format

- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.384
rev: v1.1.396
hooks:
- id: pyright
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
# ruff: noqa: T201
from pydantic import BaseModel
import msgspec

from contiguity import Base


# Create a Pydantic model for the item.
class MyItem(BaseModel):
# Create a msgspec struct for the item.
class MyItem(msgspec.Struct):
key: str # Make sure to include the key field.
name: str
age: int
interests: list[str] = []


# Create a Base instance.
# Static type checking will work with the Pydantic model.
# Static type checking will work with the msgspec struct.
db = Base("members", item_type=MyItem)

# Put an item with a specific key.
Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ classifiers = [
]
dependencies = [
"httpx>=0.27.2",
"msgspec>=0.19.0",
"phonenumbers>=8.13.47,<9.0.0",
"pydantic>=2.9.0,<3.0.0",
"typing-extensions>=4.12.2,<5.0.0",
]

Expand Down Expand Up @@ -67,6 +67,9 @@ target-version = "py39"
select = ["ALL"]
ignore = ["A", "D", "T201"]

[tool.ruff.lint.per-file-ignores]
"tests/**" = ["S101"]

[tool.pyright]
venvPath = "."
venv = ".venv"
Expand Down
14 changes: 7 additions & 7 deletions src/contiguity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,20 @@ def login(token: str, /, *, debug: bool = False) -> Contiguity:


__all__ = (
"AsyncBase",
"Contiguity",
"Send",
"Verify",
"EmailAnalytics",
"Quota",
"OTP",
"Template",
"AsyncBase",
"Base",
"BaseItem",
"Contiguity",
"EmailAnalytics",
"InvalidKeyError",
"ItemConflictError",
"ItemNotFoundError",
"QueryResponse",
"Quota",
"Send",
"Template",
"Verify",
"login",
)
__version__ = "2.0.0"
4 changes: 1 addition & 3 deletions src/contiguity/_auth.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from __future__ import annotations

import os


def _get_env_var(var_name: str, friendly_name: str | None = None) -> str:
def _get_env_var(var_name: str, friendly_name: "str | None" = None) -> str:
value = os.getenv(var_name, "")
if not value:
msg = f"no {friendly_name or var_name} provided"
Expand Down
10 changes: 4 additions & 6 deletions src/contiguity/_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from __future__ import annotations

from httpx import AsyncClient as HttpxAsyncClient
from httpx import Client as HttpxClient

Expand All @@ -12,10 +10,10 @@ class ApiError(Exception):

class ApiClient(HttpxClient):
def __init__(
self: ApiClient,
self,
*,
base_url: str = "https://api.contiguity.co",
api_key: str | None = None,
api_key: "str | None" = None,
timeout: int = 5,
) -> None:
if not api_key:
Expand All @@ -33,10 +31,10 @@ def __init__(

class AsyncApiClient(HttpxAsyncClient):
def __init__(
self: AsyncApiClient,
self,
*,
base_url: str = "https://api.contiguity.co",
api_key: str | None = None,
api_key: "str | None" = None,
timeout: int = 5,
) -> None:
if not api_key:
Expand Down
4 changes: 2 additions & 2 deletions src/contiguity/_common.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pydantic import BaseModel
import msgspec


class Crumbs(BaseModel):
class Crumbs(msgspec.Struct):
plan: str
quota: int
type: str
Expand Down
57 changes: 26 additions & 31 deletions src/contiguity/base/async_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@
from typing import TYPE_CHECKING, Generic, Literal, overload
from warnings import warn

import msgspec
from httpx import HTTPStatusError
from pydantic import BaseModel, TypeAdapter
from pydantic import JsonValue as DataType
from typing_extensions import deprecated

from contiguity._auth import get_data_key, get_project_id
from contiguity._client import ApiError, AsyncApiClient

from .common import (
UNSET,
DataType,
DefaultItemT,
ItemT,
QueryResponse,
Expand All @@ -33,7 +33,6 @@

if TYPE_CHECKING:
from httpx import Response as HttpxResponse
from typing_extensions import Self


class AsyncBase(Generic[ItemT]):
Expand All @@ -42,7 +41,7 @@ class AsyncBase(Generic[ItemT]):

@overload
def __init__(
self: Self,
self,
name: str,
/,
*,
Expand All @@ -57,7 +56,7 @@ def __init__(
@overload
@deprecated("The `project_key` parameter has been renamed to `data_key`.")
def __init__(
self: Self,
self,
name: str,
/,
*,
Expand All @@ -70,7 +69,7 @@ def __init__(
) -> None: ...

def __init__( # noqa: PLR0913
self: Self,
self,
name: str,
/,
*,
Expand All @@ -80,7 +79,7 @@ def __init__( # noqa: PLR0913
project_id: str | None = None,
host: str | None = None,
api_version: str = "v1",
json_decoder: type[json.JSONDecoder] = json.JSONDecoder, # Only used when item_type is not a Pydantic model.
json_decoder: type[json.JSONDecoder] = json.JSONDecoder, # Only used when item_type is not a msgspec struct.
) -> None:
if not name:
msg = f"invalid Base name '{name}'"
Expand All @@ -102,23 +101,23 @@ def __init__( # noqa: PLR0913

@overload
def _response_as_item_type(
self: Self,
self,
response: HttpxResponse,
/,
*,
sequence: Literal[False] = False,
) -> ItemT: ...
@overload
def _response_as_item_type(
self: Self,
self,
response: HttpxResponse,
/,
*,
sequence: Literal[True] = True,
) -> Sequence[ItemT]: ...

def _response_as_item_type(
self: Self,
self,
response: HttpxResponse,
/,
*,
Expand All @@ -130,12 +129,12 @@ def _response_as_item_type(
raise ApiError(exc.response.text) from exc
if self.item_type:
if sequence:
return TypeAdapter(Sequence[self.item_type]).validate_json(response.content)
return TypeAdapter(self.item_type).validate_json(response.content)
return msgspec.json.decode(response.content, type=Sequence[self.item_type])
return msgspec.json.decode(response.content, type=self.item_type)
return response.json(cls=self.json_decoder)

def _insert_expires_attr(
self: Self,
self,
item: ItemT | Mapping[str, DataType],
expire_in: int | None = None,
expire_at: TimestampType | None = None,
Expand All @@ -144,7 +143,7 @@ def _insert_expires_attr(
msg = "cannot use both expire_in and expire_at"
raise ValueError(msg)

item_dict = item.model_dump() if isinstance(item, BaseModel) else dict(item)
item_dict = msgspec.structs.asdict(item) if isinstance(item, msgspec.Struct) else dict(item)

if not expire_in and not expire_at:
return item_dict
Expand All @@ -160,16 +159,16 @@ def _insert_expires_attr(
return item_dict

@overload
async def get(self: Self, key: str, /) -> ItemT | None: ...
async def get(self, key: str, /) -> ItemT | None: ...

@overload
async def get(self: Self, key: str, /, *, default: ItemT) -> ItemT: ...
async def get(self, key: str, /, *, default: ItemT) -> ItemT: ...

@overload
async def get(self: Self, key: str, /, *, default: DefaultItemT) -> ItemT | DefaultItemT: ...
async def get(self, key: str, /, *, default: DefaultItemT) -> ItemT | DefaultItemT: ...

async def get(
self: Self,
self,
key: str,
/,
*,
Expand All @@ -189,7 +188,7 @@ async def get(

return self._response_as_item_type(response, sequence=False)

async def delete(self: Self, key: str, /) -> None:
async def delete(self, key: str, /) -> None:
"""Delete an item from the Base."""
key = check_key(key)
response = await self._client.delete(f"/items/{key}")
Expand All @@ -199,7 +198,7 @@ async def delete(self: Self, key: str, /) -> None:
raise ApiError(exc.response.text) from exc

async def insert(
self: Self,
self,
item: ItemT,
/,
*,
Expand All @@ -218,7 +217,7 @@ async def insert(
return returned_item[0]

async def put(
self: Self,
self,
*items: ItemT,
expire_in: int | None = None,
expire_at: TimestampType | None = None,
Expand All @@ -239,7 +238,7 @@ async def put(

@deprecated("This method will be removed in a future release. You can pass multiple items to `put`.")
async def put_many(
self: Self,
self,
items: Sequence[ItemT],
/,
*,
Expand All @@ -249,7 +248,7 @@ async def put_many(
return await self.put(*items, expire_in=expire_in, expire_at=expire_at)

async def update(
self: Self,
self,
updates: Mapping[str, DataType | UpdateOperation],
/,
*,
Expand All @@ -273,14 +272,14 @@ async def update(
expire_at=expire_at,
)

response = await self._client.patch(f"/items/{key}", json={"updates": payload.model_dump()})
response = await self._client.patch(f"/items/{key}", json={"updates": msgspec.structs.asdict(payload)})
if response.status_code == HTTPStatus.NOT_FOUND:
raise ItemNotFoundError(key)

return self._response_as_item_type(response, sequence=False)

async def query(
self: Self,
self,
*queries: QueryType,
limit: int = 1000,
last: str | None = None,
Expand All @@ -302,15 +301,11 @@ async def query(
response.raise_for_status()
except HTTPStatusError as exc:
raise ApiError(exc.response.text) from exc
query_response = QueryResponse[ItemT].model_validate_json(response.content)
if self.item_type:
# HACK: Pydantic model_validate_json doesn't validate Sequence[ItemT] properly. # noqa: FIX004
query_response.items = TypeAdapter(Sequence[self.item_type]).validate_python(query_response.items)
return query_response
return msgspec.json.decode(response.content, type=QueryResponse[ItemT])

@deprecated("This method has been renamed to `query` and will be removed in a future release.")
async def fetch(
self: Self,
self,
*queries: QueryType,
limit: int = 1000,
last: str | None = None,
Expand Down
Loading
Loading