From 28b48ae6c6ff6061ca4b3994be1c4259d6a1d9b4 Mon Sep 17 00:00:00 2001 From: mhh Date: Mon, 28 Aug 2023 15:15:22 +0200 Subject: [PATCH 01/17] squash! refactor and improve clients; extract ABC for future use with alternative client implementations --- cache/files/QmAndSoOn | 1 + 1 file changed, 1 insertion(+) create mode 100644 cache/files/QmAndSoOn diff --git a/cache/files/QmAndSoOn b/cache/files/QmAndSoOn new file mode 100644 index 00000000..d9605cba --- /dev/null +++ b/cache/files/QmAndSoOn @@ -0,0 +1 @@ +HELLO \ No newline at end of file From e9bc6abd0fc246f02f1da3dd4bf5de29368c453e Mon Sep 17 00:00:00 2001 From: mhh Date: Tue, 5 Sep 2023 14:45:30 +0200 Subject: [PATCH 02/17] add MessageCache and DomainNode, based on peewee ORM & SQLite --- setup.cfg | 3 + src/aleph/sdk/conf.py | 9 + src/aleph/sdk/node.py | 749 ++++++++++++++++++++++++++++++++++++ tests/unit/conftest.py | 70 ++++ tests/unit/test_node.py | 255 ++++++++++++ tests/unit/test_node_get.py | 231 +++++++++++ 6 files changed, 1317 insertions(+) create mode 100644 src/aleph/sdk/node.py create mode 100644 tests/unit/test_node.py create mode 100644 tests/unit/test_node_get.py diff --git a/setup.cfg b/setup.cfg index e926e128..7d1815e8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -81,6 +81,7 @@ testing = flake8 substrate-interface py-sr25519-bindings + peewee mqtt = aiomqtt<=0.1.3 certifi @@ -106,6 +107,8 @@ ledger = ledgereth==0.9.0 docs = sphinxcontrib-plantuml +cache = + peewee [options.entry_points] # Add here console scripts like: diff --git a/src/aleph/sdk/conf.py b/src/aleph/sdk/conf.py index 885bd05a..cf63cdc0 100644 --- a/src/aleph/sdk/conf.py +++ b/src/aleph/sdk/conf.py @@ -38,6 +38,15 @@ class Settings(BaseSettings): CODE_USES_SQUASHFS: bool = which("mksquashfs") is not None # True if command exists + CACHE_DATABASE_PATH: Path = Field( + default=Path(":memory:"), # can also be :memory: for in-memory caching + description="Path to the cache database", + ) + CACHE_FILES_PATH: Path = Field( + default=Path("cache", "files"), + description="Path to the cache files", + ) + class Config: env_prefix = "ALEPH_" case_sensitive = False diff --git a/src/aleph/sdk/node.py b/src/aleph/sdk/node.py new file mode 100644 index 00000000..a9548e67 --- /dev/null +++ b/src/aleph/sdk/node.py @@ -0,0 +1,749 @@ +import asyncio +import json +import logging +import typing +from datetime import datetime +from functools import partial +from pathlib import Path +from typing import ( + Any, + AsyncIterable, + Coroutine, + Dict, + Generic, + Iterable, + Iterator, + List, + Mapping, + Optional, + Tuple, + Type, + TypeVar, + Union, +) + +from aleph_message import MessagesResponse, parse_message +from aleph_message.models import ( + AlephMessage, + Chain, + ItemHash, + MessageConfirmation, + MessageType, +) +from aleph_message.models.execution.base import Encoding +from aleph_message.status import MessageStatus +from peewee import ( + BooleanField, + CharField, + FloatField, + IntegerField, + Model, + SqliteDatabase, +) +from playhouse.shortcuts import model_to_dict +from playhouse.sqlite_ext import JSONField +from pydantic import BaseModel + +from aleph.sdk import AuthenticatedAlephClient +from aleph.sdk.base import AlephClientBase, AuthenticatedAlephClientBase +from aleph.sdk.conf import settings +from aleph.sdk.exceptions import MessageNotFoundError +from aleph.sdk.models import PostsResponse +from aleph.sdk.types import GenericMessage, StorageEnum + +db = SqliteDatabase(settings.CACHE_DATABASE_PATH) +T = TypeVar("T", bound=BaseModel) + + +class 
JSONDictEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, BaseModel): + return obj.dict() + return json.JSONEncoder.default(self, obj) + + +pydantic_json_dumps = partial(json.dumps, cls=JSONDictEncoder) + + +class PydanticField(JSONField, Generic[T]): + """ + A field for storing pydantic model types as JSON in a database. Uses json for serialization. + """ + + type: T + + def __init__(self, *args, **kwargs): + self.type = kwargs.pop("type") + super().__init__(*args, **kwargs) + + def db_value(self, value: Optional[T]) -> Optional[str]: + if value is None: + return None + return value.json() + + def python_value(self, value: Optional[str]) -> Optional[T]: + if value is None: + return None + return self.type.parse_raw(value) + + +class MessageModel(Model): + """ + A simple database model for storing AlephMessage objects. + """ + + item_hash = CharField(primary_key=True) + chain = CharField(5) + type = CharField(9) + sender = CharField() + channel = CharField(null=True) + confirmations: PydanticField[MessageConfirmation] = PydanticField( + type=MessageConfirmation, null=True + ) + confirmed = BooleanField(null=True) + signature = CharField(null=True) + size = IntegerField(null=True) + time = FloatField() + item_type = CharField(7) + item_content = CharField(null=True) + hash_type = CharField(6, null=True) + content = JSONField(json_dumps=pydantic_json_dumps) + forgotten_by = CharField(null=True) + tags = JSONField(json_dumps=pydantic_json_dumps, null=True) + key = CharField(null=True) + ref = CharField(null=True) + content_type = CharField(null=True) + + class Meta: + database = db + + +def message_to_model(message: AlephMessage) -> Dict: + return { + "item_hash": str(message.item_hash), + "chain": message.chain, + "type": message.type, + "sender": message.sender, + "channel": message.channel, + "confirmations": message.confirmations[0] if message.confirmations else None, + "confirmed": message.confirmed, + "signature": message.signature, + "size": message.size, + "time": message.time, + "item_type": message.item_type, + "item_content": message.item_content, + "hash_type": message.hash_type, + "content": message.content, + "forgotten_by": message.forgotten_by[0] if message.forgotten_by else None, + "tags": message.content.content.get("tags", None) + if hasattr(message.content, "content") + else None, + "key": message.content.key if hasattr(message.content, "key") else None, + "ref": message.content.ref if hasattr(message.content, "ref") else None, + "content_type": message.content.type + if hasattr(message.content, "type") + else None, + } + + +def model_to_message(item: Any) -> AlephMessage: + item.confirmations = [item.confirmations] if item.confirmations else [] + item.forgotten_by = [item.forgotten_by] if item.forgotten_by else None + + to_exclude = [ + MessageModel.tags, + MessageModel.ref, + MessageModel.key, + MessageModel.content_type, + ] + + item_dict = model_to_dict(item, exclude=to_exclude) + return parse_message(item_dict) + + +def query_field(field_name, field_values: Iterable[str]): + field = getattr(MessageModel, field_name) + values = list(field_values) + + if len(values) == 1: + return field == values[0] + return field.in_(values) + + +def get_message_query( + message_type: Optional[MessageType] = None, + content_keys: Optional[Iterable[str]] = None, + content_types: Optional[Iterable[str]] = None, + refs: Optional[Iterable[str]] = None, + addresses: Optional[Iterable[str]] = None, + tags: Optional[Iterable[str]] = None, + hashes: 
Optional[Iterable[str]] = None, + channels: Optional[Iterable[str]] = None, + chains: Optional[Iterable[str]] = None, + start_date: Optional[Union[datetime, float]] = None, + end_date: Optional[Union[datetime, float]] = None, +): + query = MessageModel.select().order_by(MessageModel.time.desc()) + conditions = [] + if message_type: + conditions.append(query_field("type", [message_type.value])) + if content_keys: + conditions.append(query_field("key", content_keys)) + if content_types: + conditions.append(query_field("content_type", content_types)) + if refs: + conditions.append(query_field("ref", refs)) + if addresses: + conditions.append(query_field("sender", addresses)) + if tags: + for tag in tags: + conditions.append(MessageModel.tags.contains(tag)) + if hashes: + conditions.append(query_field("item_hash", hashes)) + if channels: + conditions.append(query_field("channel", channels)) + if chains: + conditions.append(query_field("chain", chains)) + if start_date: + conditions.append(MessageModel.time >= start_date) + if end_date: + conditions.append(MessageModel.time <= end_date) + + if conditions: + query = query.where(*conditions) + return query + + +class MessageCache(AlephClientBase): + """ + A wrapper around a sqlite3 database for caching AlephMessage objects. + + It can be used independently of a DomainNode to implement any kind of caching strategy. + """ + + _instance_count = 0 # Class-level counter for active instances + + def __init__(self): + if db.is_closed(): + db.connect() + if not MessageModel.table_exists(): + db.create_tables([MessageModel]) + + MessageCache._instance_count += 1 + + def __del__(self): + MessageCache._instance_count -= 1 + + if MessageCache._instance_count == 0: + db.close() + + def __getitem__(self, item_hash: Union[ItemHash, str]) -> Optional[AlephMessage]: + try: + item = MessageModel.get(MessageModel.item_hash == str(item_hash)) + except MessageModel.DoesNotExist: + return None + return model_to_message(item) + + def __delitem__(self, item_hash: Union[ItemHash, str]): + MessageModel.delete().where(MessageModel.item_hash == str(item_hash)).execute() + + def __contains__(self, item_hash: Union[ItemHash, str]) -> bool: + return ( + MessageModel.select() + .where(MessageModel.item_hash == str(item_hash)) + .exists() + ) + + def __len__(self): + return MessageModel.select().count() + + def __iter__(self) -> Iterator[AlephMessage]: + """ + Iterate over all messages in the cache, the latest first. + """ + for item in iter(MessageModel.select().order_by(-MessageModel.time)): + yield model_to_message(item) + + def __repr__(self) -> str: + return f"" + + def __str__(self) -> str: + return repr(self) + + @staticmethod + def add(messages: Union[AlephMessage, Iterable[AlephMessage]]): + if isinstance(messages, typing.get_args(AlephMessage)): + messages = [messages] + + data_source = (message_to_model(message) for message in messages) + MessageModel.insert_many(data_source).on_conflict_replace().execute() + + @staticmethod + def get( + item_hashes: Union[Union[ItemHash, str], Iterable[Union[ItemHash, str]]] + ) -> List[AlephMessage]: + """ + Get many messages from the cache by their item hash. 
+ """ + if not isinstance(item_hashes, list): + item_hashes = [item_hashes] + item_hashes = [str(item_hash) for item_hash in item_hashes] + items = ( + MessageModel.select() + .where(MessageModel.item_hash.in_(item_hashes)) + .execute() + ) + return [model_to_message(item) for item in items] + + def listen_to(self, message_stream: AsyncIterable[AlephMessage]) -> Coroutine: + """ + Listen to a stream of messages and add them to the cache. + """ + + async def _listen(): + async for message in message_stream: + self.add(message) + print(f"Added message {message.item_hash} to cache") + + return _listen() + + async def fetch_aggregate( + self, address: str, key: str, limit: int = 100 + ) -> Dict[str, Dict]: + item = ( + MessageModel.select() + .where(MessageModel.type == MessageType.aggregate.value) + .where(MessageModel.sender == address) + .where(MessageModel.key == key) + .order_by(MessageModel.time.desc()) + .first() + ) + return item.content["content"] + + async def fetch_aggregates( + self, address: str, keys: Optional[Iterable[str]] = None, limit: int = 100 + ) -> Dict[str, Dict]: + query = ( + MessageModel.select() + .where(MessageModel.type == MessageType.aggregate.value) + .where(MessageModel.sender == address) + .order_by(MessageModel.time.desc()) + ) + if keys: + query = query.where(MessageModel.key.in_(keys)) + query = query.limit(limit) + return {item.key: item.content["content"] for item in list(query)} + + async def get_posts( + self, + pagination: int = 200, + page: int = 1, + types: Optional[Iterable[str]] = None, + refs: Optional[Iterable[str]] = None, + addresses: Optional[Iterable[str]] = None, + tags: Optional[Iterable[str]] = None, + hashes: Optional[Iterable[str]] = None, + channels: Optional[Iterable[str]] = None, + chains: Optional[Iterable[str]] = None, + start_date: Optional[Union[datetime, float]] = None, + end_date: Optional[Union[datetime, float]] = None, + ignore_invalid_messages: bool = True, + invalid_messages_log_level: int = logging.NOTSET, + ) -> PostsResponse: + query = get_message_query( + message_type=MessageType.post, + content_types=types, + refs=refs, + addresses=addresses, + tags=tags, + hashes=hashes, + channels=channels, + chains=chains, + start_date=start_date, + end_date=end_date, + ) + + query = query.paginate(page, pagination) + + posts = [model_to_message(item) for item in list(query)] + + return PostsResponse( + posts=posts, + pagination_page=page, + pagination_per_page=pagination, + pagination_total=query.count(), + pagination_item="posts", + ) + + async def download_file(self, file_hash: str) -> bytes: + raise NotImplementedError + + async def get_messages( + self, + pagination: int = 200, + page: int = 1, + message_type: Optional[MessageType] = None, + content_types: Optional[Iterable[str]] = None, + content_keys: Optional[Iterable[str]] = None, + refs: Optional[Iterable[str]] = None, + addresses: Optional[Iterable[str]] = None, + tags: Optional[Iterable[str]] = None, + hashes: Optional[Iterable[str]] = None, + channels: Optional[Iterable[str]] = None, + chains: Optional[Iterable[str]] = None, + start_date: Optional[Union[datetime, float]] = None, + end_date: Optional[Union[datetime, float]] = None, + ignore_invalid_messages: bool = True, + invalid_messages_log_level: int = logging.NOTSET, + ) -> MessagesResponse: + """ + Get many messages from the cache. 
+ """ + query = get_message_query( + message_type=message_type, + content_keys=content_keys, + content_types=content_types, + refs=refs, + addresses=addresses, + tags=tags, + hashes=hashes, + channels=channels, + chains=chains, + start_date=start_date, + end_date=end_date, + ) + + query = query.paginate(page, pagination) + + messages = [model_to_message(item) for item in list(query)] + + return MessagesResponse( + messages=messages, + pagination_page=page, + pagination_per_page=pagination, + pagination_total=query.count(), + pagination_item="messages", + ) + + async def get_message( + self, + item_hash: str, + message_type: Optional[Type[GenericMessage]] = None, + channel: Optional[str] = None, + ) -> GenericMessage: + """ + Get a single message from the cache. + """ + query = MessageModel.select().where(MessageModel.item_hash == item_hash) + + if message_type: + query = query.where(MessageModel.type == message_type.value) + if channel: + query = query.where(MessageModel.channel == channel) + + item = query.first() + + if item: + return model_to_message(item) + + raise MessageNotFoundError(f"No such hash {item_hash}") + + async def watch_messages( + self, + message_type: Optional[MessageType] = None, + content_types: Optional[Iterable[str]] = None, + content_keys: Optional[Iterable[str]] = None, + refs: Optional[Iterable[str]] = None, + addresses: Optional[Iterable[str]] = None, + tags: Optional[Iterable[str]] = None, + hashes: Optional[Iterable[str]] = None, + channels: Optional[Iterable[str]] = None, + chains: Optional[Iterable[str]] = None, + start_date: Optional[Union[datetime, float]] = None, + end_date: Optional[Union[datetime, float]] = None, + ) -> AsyncIterable[AlephMessage]: + """ + Watch messages from the cache. + """ + query = get_message_query( + message_type=message_type, + content_keys=content_keys, + content_types=content_types, + refs=refs, + addresses=addresses, + tags=tags, + hashes=hashes, + channels=channels, + chains=chains, + start_date=start_date, + end_date=end_date, + ) + + async for item in query: + yield model_to_message(item) + + +class DomainNode(MessageCache, AuthenticatedAlephClientBase): + """ + A Domain Node is a queryable proxy for Aleph Messages that are stored in a database cache and/or in the Aleph + network. + + It synchronizes with the network on a subset of the messages by listening to the network and storing the + messages in the cache. The user may define the subset by specifying a channels, tags, senders, chains, + message types, and/or a time window. 
+ """ + + def __init__( + self, + session: AuthenticatedAlephClient, + channels: Optional[Iterable[str]] = None, + tags: Optional[Iterable[str]] = None, + addresses: Optional[Iterable[str]] = None, + chains: Optional[Iterable[Chain]] = None, + message_type: Optional[MessageType] = None, + ): + super().__init__() + self.session = session + self.channels = channels + self.tags = tags + self.addresses = addresses + self.chains = chains + self.message_type = message_type + + # start listening to the network and storing messages in the cache + asyncio.get_event_loop().create_task( + self.listen_to( + self.session.watch_messages( + channels=self.channels, + tags=self.tags, + addresses=self.addresses, + chains=self.chains, + message_type=self.message_type, + ) + ) + ) + + # synchronize with past messages + asyncio.get_event_loop().run_until_complete( + self.synchronize( + channels=self.channels, + tags=self.tags, + addresses=self.addresses, + chains=self.chains, + message_type=self.message_type, + ) + ) + + async def __aenter__(self) -> "DomainNode": + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + ... + + async def synchronize( + self, + channels: Optional[Iterable[str]] = None, + tags: Optional[Iterable[str]] = None, + addresses: Optional[Iterable[str]] = None, + chains: Optional[Iterable[Chain]] = None, + message_type: Optional[MessageType] = None, + start_date: Optional[Union[datetime, float]] = None, + end_date: Optional[Union[datetime, float]] = None, + ): + """ + Synchronize with past messages. + """ + chunk_size = 200 + messages = [] + async for message in self.session.get_messages_iterator( + channels=channels, + tags=tags, + addresses=addresses, + chains=chains, + message_type=message_type, + start_date=start_date, + end_date=end_date, + ): + messages.append(message) + if len(messages) >= chunk_size: + self.add(messages) + messages = [] + if messages: + self.add(messages) + + async def download_file(self, file_hash: str) -> bytes: + """ + Opens a file that has been locally stored by its hash. 
+ """ + try: + with open(self._file_path(file_hash), "rb") as f: + return f.read() + except FileNotFoundError: + file = await self.session.download_file(file_hash) + self._file_path(file_hash).parent.mkdir(parents=True, exist_ok=True) + with open(self._file_path(file_hash), "wb") as f: + f.write(file) + return file + + @staticmethod + def _file_path(file_hash: str) -> Path: + return settings.CACHE_FILES_PATH / Path(file_hash) + + async def create_post( + self, + post_content: Any, + post_type: str, + ref: Optional[str] = None, + address: Optional[str] = None, + channel: Optional[str] = None, + inline: bool = True, + storage_engine: StorageEnum = StorageEnum.storage, + sync: bool = False, + ) -> Tuple[AlephMessage, MessageStatus]: + resp, status = await self.session.create_post( + post_content=post_content, + post_type=post_type, + ref=ref, + address=address, + channel=channel, + inline=inline, + storage_engine=storage_engine, + sync=sync, + ) + # WARNING: this can cause inconsistencies if the message is dropped/rejected by the aleph node + if status in [MessageStatus.PENDING, MessageStatus.PROCESSED]: + self.add(resp) + return resp, status + + async def create_aggregate( + self, + key: str, + content: Mapping[str, Any], + address: Optional[str] = None, + channel: Optional[str] = None, + inline: bool = True, + sync: bool = False, + ) -> Tuple[AlephMessage, MessageStatus]: + resp, status = await self.session.create_aggregate( + key=key, + content=content, + address=address, + channel=channel, + inline=inline, + sync=sync, + ) + if status in [MessageStatus.PENDING, MessageStatus.PROCESSED]: + self.add(resp) + return resp, status + + async def create_store( + self, + address: Optional[str] = None, + file_content: Optional[bytes] = None, + file_path: Optional[Union[str, Path]] = None, + file_hash: Optional[str] = None, + guess_mime_type: bool = False, + ref: Optional[str] = None, + storage_engine: StorageEnum = StorageEnum.storage, + extra_fields: Optional[dict] = None, + channel: Optional[str] = None, + sync: bool = False, + ) -> Tuple[AlephMessage, MessageStatus]: + resp, status = await self.session.create_store( + address=address, + file_content=file_content, + file_path=file_path, + file_hash=file_hash, + guess_mime_type=guess_mime_type, + ref=ref, + storage_engine=storage_engine, + extra_fields=extra_fields, + channel=channel, + sync=sync, + ) + if status in [MessageStatus.PENDING, MessageStatus.PROCESSED]: + self.add(resp) + return resp, status + + async def create_program( + self, + program_ref: str, + entrypoint: str, + runtime: str, + environment_variables: Optional[Mapping[str, str]] = None, + storage_engine: StorageEnum = StorageEnum.storage, + channel: Optional[str] = None, + address: Optional[str] = None, + sync: bool = False, + memory: Optional[int] = None, + vcpus: Optional[int] = None, + timeout_seconds: Optional[float] = None, + persistent: bool = False, + encoding: Encoding = Encoding.zip, + volumes: Optional[List[Mapping]] = None, + subscriptions: Optional[List[Mapping]] = None, + metadata: Optional[Mapping[str, Any]] = None, + ) -> Tuple[AlephMessage, MessageStatus]: + resp, status = await self.session.create_program( + program_ref=program_ref, + entrypoint=entrypoint, + runtime=runtime, + environment_variables=environment_variables, + storage_engine=storage_engine, + channel=channel, + address=address, + sync=sync, + memory=memory, + vcpus=vcpus, + timeout_seconds=timeout_seconds, + persistent=persistent, + encoding=encoding, + volumes=volumes, + 
subscriptions=subscriptions, + metadata=metadata, + ) + if status in [MessageStatus.PENDING, MessageStatus.PROCESSED]: + self.add(resp) + return resp, status + + async def forget( + self, + hashes: List[str], + reason: Optional[str], + storage_engine: StorageEnum = StorageEnum.storage, + channel: Optional[str] = None, + address: Optional[str] = None, + sync: bool = False, + ) -> Tuple[AlephMessage, MessageStatus]: + resp, status = await self.session.forget( + hashes=hashes, + reason=reason, + storage_engine=storage_engine, + channel=channel, + address=address, + sync=sync, + ) + del self[resp.item_hash] + return resp, status + + async def submit( + self, + content: Dict[str, Any], + message_type: MessageType, + channel: Optional[str] = None, + storage_engine: StorageEnum = StorageEnum.storage, + allow_inlining: bool = True, + sync: bool = False, + ) -> Tuple[AlephMessage, MessageStatus]: + resp, status = await self.session.submit( + content=content, + message_type=message_type, + channel=channel, + storage_engine=storage_engine, + allow_inlining=allow_inlining, + sync=sync, + ) + # WARNING: this can cause inconsistencies if the message is dropped/rejected by the aleph node + if status in [MessageStatus.PROCESSED, MessageStatus.PENDING]: + self.add(resp) + return resp, status diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 4f62c0c5..5a677341 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,8 +1,10 @@ import json from pathlib import Path from tempfile import NamedTemporaryFile +from typing import List import pytest as pytest +from aleph_message.models import AggregateMessage, AlephMessage, PostMessage import aleph.sdk.chains.ethereum as ethereum import aleph.sdk.chains.sol as solana @@ -50,3 +52,71 @@ def messages(): messages_path = Path(__file__).parent / "messages.json" with open(messages_path) as f: return json.load(f) + + +@pytest.fixture +def messages() -> List[AlephMessage]: + return [ + AggregateMessage.parse_obj( + { + "item_hash": "5b26d949fe05e38f535ef990a89da0473f9d700077cced228f2d36e73fca1fd6", + "type": "AGGREGATE", + "chain": "ETH", + "sender": "0x51A58800b26AA1451aaA803d1746687cB88E0501", + "signature": "0xca5825b6b93390482b436cb7f28b4628f8c9f56dc6af08260c869b79dd6017c94248839bd9fd0ffa1230dc3b1f4f7572a8d1f6fed6c6e1fb4d70ccda0ab5d4f21b", + "item_type": "inline", + "item_content": '{"address":"0x51A58800b26AA1451aaA803d1746687cB88E0501","key":"0xce844d79e5c0c325490c530aa41e8f602f0b5999binance","content":{"1692026263168":{"version":"x25519-xsalsa20-poly1305","nonce":"RT4Lbqs7Xzk+op2XC+VpXgwOgg21BotN","ephemPublicKey":"CVW8ECE3m8BepytHMTLan6/jgIfCxGdnKmX47YirF08=","ciphertext":"VuGJ9vMkJSbaYZCCv6Zemx4ixeb+9IW8H1vFB9vLtz1a8d87R4BfYUisLoCQxRkeUXqfW0/KIGQ5idVjr8Yj7QnKglW5AJ8UX7wEWMhiRFLatpWP8P9FI2n8Z7Rblu7Oz/OeKnuljKL3KsalcUQSsFa/1qACsIoycPZ6Wq6t1mXxVxxJWzClLyKRihv1pokZGT9UWxh7+tpoMGlRdYainyAt0/RygFw+r8iCMOilHnyv4ndLkKQJXyttb0tdNr/gr57+9761+trioGSysLQKZQWW6Ih6aE8V9t3BenfzYwiCnfFw3YAAKBPMdm9QdIETyrOi7YhD/w==","sha256":"bbeb499f681aed2bc18b6f3b6a30d25254bd30fbfde43444e9085f3bcd075c3c"}},"time":1692026263.662}', + "content": { + "key": "0xce844d79e5c0c325490c530aa41e8f602f0b5999binance", + "time": 1692026263.662, + "address": "0x51A58800b26AA1451aaA803d1746687cB88E0501", + "content": { + "hello": "world", + }, + }, + "time": 1692026263.662, + "channel": "UNSLASHED", + "size": 734, + "confirmations": [], + "confirmed": False, + } + ), + PostMessage.parse_obj( + { + "item_hash": "70f3798fdc68ce0ee03715a5547ee24e2c3e259bf02e3f5d1e4bf5a6f6a5e99f", + 
"type": "POST", + "chain": "SOL", + "sender": "0x4D52380D3191274a04846c89c069E6C3F2Ed94e4", + "signature": "0x91616ee45cfba55742954ff87ebf86db4988bcc5e3334b49a4caa6436e28e28d4ab38667cbd4bfb8903abf8d71f70d9ceb2c0a8d0a15c04fc1af5657f0050c101b", + "item_type": "storage", + "item_content": None, + "content": { + "time": 1692026021.1257718, + "type": "aleph-network-metrics", + "address": "0x4D52380D3191274a04846c89c069E6C3F2Ed94e4", + "ref": "0123456789abcdef", + "content": { + "tags": ["mainnet"], + "hello": "world", + "version": "1.0", + }, + }, + "time": 1692026021.132849, + "channel": "aleph-scoring", + "size": 122537, + "confirmations": [], + "confirmed": False, + } + ), + ] + + +@pytest.fixture +def raw_messages_response(messages): + return { + "messages": [message.dict() for message in messages], + "pagination_item": "messages", + "pagination_page": 1, + "pagination_per_page": 20, + "pagination_total": 2, + } diff --git a/tests/unit/test_node.py b/tests/unit/test_node.py new file mode 100644 index 00000000..0b844e50 --- /dev/null +++ b/tests/unit/test_node.py @@ -0,0 +1,255 @@ +import json +import os +from pathlib import Path +from typing import Any, Dict +from unittest.mock import AsyncMock, MagicMock + +import pytest as pytest +from aleph_message.models import ( + AggregateMessage, + ForgetMessage, + MessageType, + PostMessage, + ProgramMessage, + StoreMessage, +) +from aleph_message.status import MessageStatus + +from aleph.sdk import AuthenticatedAlephClient +from aleph.sdk.conf import settings +from aleph.sdk.node import DomainNode +from aleph.sdk.types import Account, StorageEnum + + +class MockPostResponse: + def __init__(self, response_message: Any, sync: bool): + self.response_message = response_message + self.sync = sync + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + ... + + @property + def status(self): + return 200 if self.sync else 202 + + def raise_for_status(self): + if self.status not in [200, 202]: + raise Exception("Bad status code") + + async def json(self): + message_status = "processed" if self.sync else "pending" + return { + "message_status": message_status, + "publication_status": {"status": "success", "failed": []}, + "hash": "QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy", + "message": self.response_message, + } + + async def text(self): + return json.dumps(await self.json()) + + +class MockGetResponse: + def __init__(self, response): + self.response = response + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + ... 
+
+    @property
+    def status(self):
+        return 200
+
+    def raise_for_status(self):
+        if self.status != 200:
+            raise Exception("Bad status code")
+
+    async def json(self):
+        return self.response
+
+
+@pytest.fixture
+def mock_session_with_two_messages(
+    ethereum_account: Account, raw_messages_response: Dict[str, Any]
+) -> AuthenticatedAlephClient:
+    http_session = AsyncMock()
+    http_session.post = MagicMock()
+    http_session.post.side_effect = lambda *args, **kwargs: MockPostResponse(
+        response_message={
+            "type": "post",
+            "channel": "TEST",
+            "content": {"Hello": "World"},
+            "key": "QmBlahBlahBlah",
+            "item_hash": "QmBlahBlahBlah",
+        },
+        sync=kwargs.get("sync", False),
+    )
+    http_session.get = MagicMock()
+    http_session.get.return_value = MockGetResponse(raw_messages_response)
+
+    client = AuthenticatedAlephClient(
+        account=ethereum_account, api_server="http://localhost"
+    )
+    client.http_session = http_session
+
+    return client
+
+
+# Note: DomainNode drives the event loop in __init__, so this test stays synchronous
+def test_node_init(mock_session_with_two_messages):
+    node = DomainNode(session=mock_session_with_two_messages)
+    assert node.session == mock_session_with_two_messages
+    assert len(node) >= 2
+
+
+@pytest.fixture
+def mock_node_with_post_success(mock_session_with_two_messages) -> DomainNode:
+    node = DomainNode(session=mock_session_with_two_messages)
+    return node
+
+
+@pytest.mark.asyncio
+async def test_create_post(mock_node_with_post_success):
+    async with mock_node_with_post_success as session:
+        content = {"Hello": "World"}
+
+        post_message, message_status = await session.create_post(
+            post_content=content,
+            post_type="TEST",
+            channel="TEST",
+            sync=False,
+        )
+
+    mock_node_with_post_success.session.http_session.post.assert_called_once()
+    assert isinstance(post_message, PostMessage)
+    assert message_status == MessageStatus.PENDING
+
+
+@pytest.mark.asyncio
+async def test_create_aggregate(mock_node_with_post_success):
+    async with mock_node_with_post_success as session:
+        aggregate_message, message_status = await session.create_aggregate(
+            key="hello",
+            content={"Hello": "world"},
+            channel="TEST",
+        )
+
+    mock_node_with_post_success.session.http_session.post.assert_called_once()
+    assert isinstance(aggregate_message, AggregateMessage)
+
+
+@pytest.mark.asyncio
+async def test_create_store(mock_node_with_post_success):
+    mock_ipfs_push_file = AsyncMock()
+    mock_ipfs_push_file.return_value = "QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy"
+
+    mock_node_with_post_success.ipfs_push_file = mock_ipfs_push_file
+
+    async with mock_node_with_post_success as node:
+        _ = await node.create_store(
+            file_content=b"HELLO",
+            channel="TEST",
+            storage_engine=StorageEnum.ipfs,
+        )
+
+        _ = await node.create_store(
+            file_hash="QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy",
+            channel="TEST",
+            storage_engine=StorageEnum.ipfs,
+        )
+
+    mock_storage_push_file = AsyncMock()
+    mock_storage_push_file.return_value = (
+        "QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy"
+    )
+    mock_node_with_post_success.storage_push_file = mock_storage_push_file
+    async with mock_node_with_post_success as node:
+        store_message, message_status = await node.create_store(
+            file_content=b"HELLO",
+            channel="TEST",
+            storage_engine=StorageEnum.storage,
+        )
+
+    assert mock_node_with_post_success.session.http_session.post.called
+    assert isinstance(store_message, StoreMessage)
+
+
+@pytest.mark.asyncio
+async def test_create_program(mock_node_with_post_success):
+    async with mock_node_with_post_success as node:
+        program_message, message_status = await node.create_program(
program_ref="cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe", + entrypoint="main:app", + runtime="facefacefacefacefacefacefacefacefacefacefacefacefacefacefaceface", + channel="TEST", + metadata={"tags": ["test"]}, + ) + + assert mock_node_with_post_success.session.http_session.post.called_once + assert isinstance(program_message, ProgramMessage) + + +@pytest.mark.asyncio +async def test_forget(mock_node_with_post_success): + async with mock_node_with_post_success as node: + forget_message, message_status = await node.forget( + hashes=["QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy"], + reason="GDPR", + channel="TEST", + ) + + assert mock_node_with_post_success.session.http_session.post.called_once + assert isinstance(forget_message, ForgetMessage) + + +@pytest.mark.asyncio +async def test_download_file(mock_node_with_post_success): + mock_node_with_post_success.session.download_file = AsyncMock() + mock_node_with_post_success.session.download_file.return_value = b"HELLO" + + # remove file locally + if os.path.exists(settings.CACHE_FILES_PATH / Path("QmAndSoOn")): + os.remove(settings.CACHE_FILES_PATH / Path("QmAndSoOn")) + + # fetch from mocked response + async with mock_node_with_post_success as node: + file_content = await node.download_file( + file_hash="QmAndSoOn", + ) + + assert mock_node_with_post_success.session.http_session.get.called_once + assert file_content == b"HELLO" + + # fetch cached + async with mock_node_with_post_success as node: + file_content = await node.download_file( + file_hash="QmAndSoOn", + ) + + assert file_content == b"HELLO" + + +@pytest.mark.asyncio +async def test_submit_message(mock_node_with_post_success): + content = {"Hello": "World"} + async with mock_node_with_post_success as node: + message, status = await node.submit( + content={ + "address": "0x1234567890123456789012345678901234567890", + "time": 1234567890, + "type": "TEST", + "content": content, + }, + message_type=MessageType.post, + ) + + assert mock_node_with_post_success.session.http_session.post.called_once + assert message.content.content == content + assert status == MessageStatus.PENDING diff --git a/tests/unit/test_node_get.py b/tests/unit/test_node_get.py new file mode 100644 index 00000000..48bff3b8 --- /dev/null +++ b/tests/unit/test_node_get.py @@ -0,0 +1,231 @@ +import json +from hashlib import sha256 +from typing import List + +import pytest +from aleph_message.models import ( + AlephMessage, + Chain, + MessageType, + PostContent, + PostMessage, +) + +from aleph.sdk.chains.ethereum import get_fallback_account +from aleph.sdk.exceptions import MessageNotFoundError +from aleph.sdk.node import MessageCache + + +@pytest.mark.asyncio +async def test_base(messages): + # test add_many + cache = MessageCache() + cache.add(messages) + assert len(cache) == len(messages) + + item_hashes = [message.item_hash for message in messages] + cached_messages = cache.get(item_hashes) + assert len(cached_messages) == len(messages) + + for message in messages: + assert cache[message.item_hash] == message + + for message in messages: + assert message.item_hash in cache + + for message in cache: + del cache[message.item_hash] + assert message.item_hash not in cache + + assert len(cache) == 0 + del cache + + +class TestMessageQueries: + messages: List[AlephMessage] + cache: MessageCache + + @pytest.fixture(autouse=True) + def class_setup(self, messages): + self.messages = messages + self.cache = MessageCache() + self.cache.add(self.messages) + + def class_teardown(self): + del self.cache + 
+ @pytest.mark.asyncio + async def test_iterate(self): + assert len(self.cache) == len(self.messages) + for message in self.cache: + assert message in self.messages + + @pytest.mark.asyncio + async def test_addresses(self): + items = ( + await self.cache.get_messages(addresses=[self.messages[0].sender]) + ).messages + assert items[0] == self.messages[0] + + @pytest.mark.asyncio + async def test_tags(self): + assert ( + len((await self.cache.get_messages(tags=["thistagdoesnotexist"])).messages) + == 0 + ) + + @pytest.mark.asyncio + async def test_message_type(self): + assert (await self.cache.get_messages(message_type=MessageType.post)).messages[ + 0 + ] == self.messages[1] + + @pytest.mark.asyncio + async def test_refs(self): + assert ( + await self.cache.get_messages(refs=[self.messages[1].content.ref]) + ).messages[0] == self.messages[1] + + @pytest.mark.asyncio + async def test_hashes(self): + assert ( + await self.cache.get_messages(hashes=[self.messages[0].item_hash]) + ).messages[0] == self.messages[0] + + @pytest.mark.asyncio + async def test_pagination(self): + assert len((await self.cache.get_messages(pagination=1)).messages) == 1 + + @pytest.mark.asyncio + async def test_content_types(self): + assert ( + await self.cache.get_messages(content_types=[self.messages[1].content.type]) + ).messages[0] == self.messages[1] + + @pytest.mark.asyncio + async def test_channels(self): + assert ( + await self.cache.get_messages(channels=[self.messages[1].channel]) + ).messages[0] == self.messages[1] + + @pytest.mark.asyncio + async def test_chains(self): + assert ( + await self.cache.get_messages(chains=[self.messages[1].chain]) + ).messages[0] == self.messages[1] + + @pytest.mark.asyncio + async def test_content_keys(self): + assert ( + await self.cache.get_messages(content_keys=[self.messages[0].content.key]) + ).messages[0] == self.messages[0] + + +class TestPostQueries: + messages: List[AlephMessage] + cache: MessageCache + + @pytest.fixture(autouse=True) + def class_setup(self, messages): + self.messages = messages + self.cache = MessageCache() + self.cache.add(self.messages) + + def class_teardown(self): + del self.cache + + @pytest.mark.asyncio + async def test_addresses(self): + items = (await self.cache.get_posts(addresses=[self.messages[1].sender])).posts + assert items[0] == self.messages[1] + + @pytest.mark.asyncio + async def test_tags(self): + assert ( + len((await self.cache.get_posts(tags=["thistagdoesnotexist"])).posts) == 0 + ) + + @pytest.mark.asyncio + async def test_types(self): + assert ( + len((await self.cache.get_posts(types=["thistypedoesnotexist"])).posts) == 0 + ) + + @pytest.mark.asyncio + async def test_channels(self): + assert (await self.cache.get_posts(channels=[self.messages[1].channel])).posts[ + 0 + ] == self.messages[1] + + @pytest.mark.asyncio + async def test_chains(self): + assert (await self.cache.get_posts(chains=[self.messages[1].chain])).posts[ + 0 + ] == self.messages[1] + + +@pytest.mark.asyncio +async def test_message_cache_listener(): + async def mock_message_stream(): + for i in range(3): + content = PostContent( + content={"hello": f"world{i}"}, + type="test", + address=get_fallback_account().get_address(), + time=0, + ) + message = PostMessage( + sender=get_fallback_account().get_address(), + item_hash=sha256(json.dumps(content.dict()).encode()).hexdigest(), + chain=Chain.ETH.value, + type=MessageType.post.value, + item_type="inline", + time=0, + content=content, + item_content=json.dumps(content.dict()), + ) + yield message + + cache = 
MessageCache() + # test listener + coro = cache.listen_to(mock_message_stream()) + await coro + assert len(cache) >= 3 + + +@pytest.mark.asyncio +async def test_fetch_aggregate(messages): + cache = MessageCache() + cache.add(messages) + + aggregate = await cache.fetch_aggregate(messages[0].sender, messages[0].content.key) + + assert aggregate == messages[0].content.content + + +@pytest.mark.asyncio +async def test_fetch_aggregates(messages): + cache = MessageCache() + cache.add(messages) + + aggregates = await cache.fetch_aggregates(messages[0].sender) + + assert aggregates == {messages[0].content.key: messages[0].content.content} + + +@pytest.mark.asyncio +async def test_get_message(messages): + cache = MessageCache() + cache.add(messages) + + message: AlephMessage = await cache.get_message(messages[0].item_hash) + + assert message == messages[0] + + +@pytest.mark.asyncio +async def test_get_message_fail(): + cache = MessageCache() + + with pytest.raises(MessageNotFoundError): + await cache.get_message("0x1234567890123456789012345678901234567890") From 2e74a5e68e422d9d4466e6a2aa9d5d417ad8afaa Mon Sep 17 00:00:00 2001 From: mhh Date: Wed, 6 Sep 2023 21:03:39 +0200 Subject: [PATCH 03/17] fix redeclaration of messages in test --- tests/unit/conftest.py | 2 +- tests/unit/test_chain_ethereum.py | 4 ++-- tests/unit/test_chain_solana.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 5a677341..cd071744 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -48,7 +48,7 @@ def substrate_account() -> substrate.DOTAccount: @pytest.fixture -def messages(): +def json_messages(): messages_path = Path(__file__).parent / "messages.json" with open(messages_path) as f: return json.load(f) diff --git a/tests/unit/test_chain_ethereum.py b/tests/unit/test_chain_ethereum.py index dea58c69..9a602b3d 100644 --- a/tests/unit/test_chain_ethereum.py +++ b/tests/unit/test_chain_ethereum.py @@ -82,8 +82,8 @@ async def test_verify_signature(ethereum_account): @pytest.mark.asyncio -async def test_verify_signature_with_processed_message(ethereum_account, messages): - message = messages[1] +async def test_verify_signature_with_processed_message(ethereum_account, json_messages): + message = json_messages[1] verify_signature( message["signature"], message["sender"], get_verification_buffer(message) ) diff --git a/tests/unit/test_chain_solana.py b/tests/unit/test_chain_solana.py index 5088158a..07b67602 100644 --- a/tests/unit/test_chain_solana.py +++ b/tests/unit/test_chain_solana.py @@ -103,8 +103,8 @@ async def test_verify_signature(solana_account): @pytest.mark.asyncio -async def test_verify_signature_with_processed_message(solana_account, messages): - message = messages[0] +async def test_verify_signature_with_processed_message(solana_account, json_messages): + message = json_messages[0] signature = json.loads(message["signature"])["signature"] verify_signature(signature, message["sender"], get_verification_buffer(message)) From 16d2d23b53dd671e69cf6c0178da52ad22911b98 Mon Sep 17 00:00:00 2001 From: mhh Date: Wed, 6 Sep 2023 21:04:25 +0200 Subject: [PATCH 04/17] use posts API v1 instead of v0 --- src/aleph/sdk/client.py | 2 +- src/aleph/sdk/models.py | 26 +++++++++++++------------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/aleph/sdk/client.py b/src/aleph/sdk/client.py index f79f0ceb..4f898b72 100644 --- a/src/aleph/sdk/client.py +++ b/src/aleph/sdk/client.py @@ -617,7 +617,7 @@ async def get_posts( 
end_date = end_date.timestamp() params["endDate"] = end_date - async with self.http_session.get("/api/v0/posts.json", params=params) as resp: + async with self.http_session.get("/api/v1/posts.json", params=params) as resp: resp.raise_for_status() response_json = await resp.json() posts_raw = response_json["posts"] diff --git a/src/aleph/sdk/models.py b/src/aleph/sdk/models.py index f5b1072b..f8cdec9d 100644 --- a/src/aleph/sdk/models.py +++ b/src/aleph/sdk/models.py @@ -1,6 +1,7 @@ -from typing import Any, Dict, List, Optional, Union +from datetime import datetime +from typing import Any, Dict, List, Optional -from aleph_message.models import AlephMessage, BaseMessage, ChainRef, ItemHash +from aleph_message.models import AlephMessage, ItemHash from pydantic import BaseModel, Field @@ -18,29 +19,28 @@ class MessagesResponse(PaginationResponse): pagination_item = "messages" -class Post(BaseMessage): +class Post(BaseModel): """ A post is a type of message that can be updated. Over the get_posts API we get the latest version of a post. """ - hash: ItemHash = Field(description="Hash of the content (sha256 by default)") + item_hash: ItemHash = Field(description="Hash of the content (sha256 by default)") + content: Dict[str, Any] = Field( + description="The content.content of the POST message" + ) original_item_hash: ItemHash = Field( description="Hash of the original content (sha256 by default)" ) - original_signature: Optional[str] = Field( - description="Cryptographic signature of the original message by the sender" - ) original_type: str = Field( description="The original, user-generated 'content-type' of the POST message" ) - content: Dict[str, Any] = Field( - description="The content.content of the POST message" - ) - type: str = Field(description="The content.type of the POST message") address: str = Field(description="The address of the sender of the POST message") - ref: Optional[Union[str, ChainRef]] = Field( - description="Other message referenced by this one" + ref: Optional[str] = Field(description="Other message referenced by this one") + channel: str = Field(description="The channel where the POST message was published") + created: datetime = Field(description="The time when the POST message was created") + last_updated: datetime = Field( + description="The time when the POST message was last updated" ) From 2c2eadfcdc69a3e7226b79ea74bf91db261e259f Mon Sep 17 00:00:00 2001 From: mhh Date: Wed, 6 Sep 2023 21:05:11 +0200 Subject: [PATCH 05/17] add posts table for caching posts; handle amend messages; refactor node.py as a package --- src/aleph/sdk/{node.py => node/__init__.py} | 285 ++++++-------------- src/aleph/sdk/node/common.py | 44 +++ src/aleph/sdk/node/message.py | 137 ++++++++++ src/aleph/sdk/node/post.py | 115 ++++++++ 4 files changed, 372 insertions(+), 209 deletions(-) rename src/aleph/sdk/{node.py => node/__init__.py} (71%) create mode 100644 src/aleph/sdk/node/common.py create mode 100644 src/aleph/sdk/node/message.py create mode 100644 src/aleph/sdk/node/post.py diff --git a/src/aleph/sdk/node.py b/src/aleph/sdk/node/__init__.py similarity index 71% rename from src/aleph/sdk/node.py rename to src/aleph/sdk/node/__init__.py index a9548e67..1477ac2c 100644 --- a/src/aleph/sdk/node.py +++ b/src/aleph/sdk/node/__init__.py @@ -1,16 +1,13 @@ import asyncio -import json import logging import typing from datetime import datetime -from functools import partial from pathlib import Path from typing import ( Any, AsyncIterable, Coroutine, Dict, - Generic, Iterable, Iterator, List, 
@@ -18,203 +15,26 @@ Optional, Tuple, Type, - TypeVar, Union, ) -from aleph_message import MessagesResponse, parse_message -from aleph_message.models import ( - AlephMessage, - Chain, - ItemHash, - MessageConfirmation, - MessageType, -) +from aleph_message import MessagesResponse +from aleph_message.models import AlephMessage, Chain, ItemHash, MessageType, PostMessage from aleph_message.models.execution.base import Encoding from aleph_message.status import MessageStatus -from peewee import ( - BooleanField, - CharField, - FloatField, - IntegerField, - Model, - SqliteDatabase, -) -from playhouse.shortcuts import model_to_dict -from playhouse.sqlite_ext import JSONField -from pydantic import BaseModel - -from aleph.sdk import AuthenticatedAlephClient -from aleph.sdk.base import AlephClientBase, AuthenticatedAlephClientBase -from aleph.sdk.conf import settings -from aleph.sdk.exceptions import MessageNotFoundError -from aleph.sdk.models import PostsResponse -from aleph.sdk.types import GenericMessage, StorageEnum - -db = SqliteDatabase(settings.CACHE_DATABASE_PATH) -T = TypeVar("T", bound=BaseModel) - - -class JSONDictEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, BaseModel): - return obj.dict() - return json.JSONEncoder.default(self, obj) - - -pydantic_json_dumps = partial(json.dumps, cls=JSONDictEncoder) - - -class PydanticField(JSONField, Generic[T]): - """ - A field for storing pydantic model types as JSON in a database. Uses json for serialization. - """ - - type: T - def __init__(self, *args, **kwargs): - self.type = kwargs.pop("type") - super().__init__(*args, **kwargs) +from ..base import BaseAlephClient, BaseAuthenticatedAlephClient +from ..client import AuthenticatedAlephClient +from ..conf import settings +from ..exceptions import MessageNotFoundError +from ..models import PostsResponse +from ..types import GenericMessage, StorageEnum +from .common import db +from .message import MessageModel, get_message_query, message_to_model, model_to_message +from .post import PostModel, get_post_query, message_to_post, model_to_post - def db_value(self, value: Optional[T]) -> Optional[str]: - if value is None: - return None - return value.json() - - def python_value(self, value: Optional[str]) -> Optional[T]: - if value is None: - return None - return self.type.parse_raw(value) - - -class MessageModel(Model): - """ - A simple database model for storing AlephMessage objects. 
- """ - item_hash = CharField(primary_key=True) - chain = CharField(5) - type = CharField(9) - sender = CharField() - channel = CharField(null=True) - confirmations: PydanticField[MessageConfirmation] = PydanticField( - type=MessageConfirmation, null=True - ) - confirmed = BooleanField(null=True) - signature = CharField(null=True) - size = IntegerField(null=True) - time = FloatField() - item_type = CharField(7) - item_content = CharField(null=True) - hash_type = CharField(6, null=True) - content = JSONField(json_dumps=pydantic_json_dumps) - forgotten_by = CharField(null=True) - tags = JSONField(json_dumps=pydantic_json_dumps, null=True) - key = CharField(null=True) - ref = CharField(null=True) - content_type = CharField(null=True) - - class Meta: - database = db - - -def message_to_model(message: AlephMessage) -> Dict: - return { - "item_hash": str(message.item_hash), - "chain": message.chain, - "type": message.type, - "sender": message.sender, - "channel": message.channel, - "confirmations": message.confirmations[0] if message.confirmations else None, - "confirmed": message.confirmed, - "signature": message.signature, - "size": message.size, - "time": message.time, - "item_type": message.item_type, - "item_content": message.item_content, - "hash_type": message.hash_type, - "content": message.content, - "forgotten_by": message.forgotten_by[0] if message.forgotten_by else None, - "tags": message.content.content.get("tags", None) - if hasattr(message.content, "content") - else None, - "key": message.content.key if hasattr(message.content, "key") else None, - "ref": message.content.ref if hasattr(message.content, "ref") else None, - "content_type": message.content.type - if hasattr(message.content, "type") - else None, - } - - -def model_to_message(item: Any) -> AlephMessage: - item.confirmations = [item.confirmations] if item.confirmations else [] - item.forgotten_by = [item.forgotten_by] if item.forgotten_by else None - - to_exclude = [ - MessageModel.tags, - MessageModel.ref, - MessageModel.key, - MessageModel.content_type, - ] - - item_dict = model_to_dict(item, exclude=to_exclude) - return parse_message(item_dict) - - -def query_field(field_name, field_values: Iterable[str]): - field = getattr(MessageModel, field_name) - values = list(field_values) - - if len(values) == 1: - return field == values[0] - return field.in_(values) - - -def get_message_query( - message_type: Optional[MessageType] = None, - content_keys: Optional[Iterable[str]] = None, - content_types: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, -): - query = MessageModel.select().order_by(MessageModel.time.desc()) - conditions = [] - if message_type: - conditions.append(query_field("type", [message_type.value])) - if content_keys: - conditions.append(query_field("key", content_keys)) - if content_types: - conditions.append(query_field("content_type", content_types)) - if refs: - conditions.append(query_field("ref", refs)) - if addresses: - conditions.append(query_field("sender", addresses)) - if tags: - for tag in tags: - conditions.append(MessageModel.tags.contains(tag)) - if hashes: - conditions.append(query_field("item_hash", hashes)) - if channels: - 
conditions.append(query_field("channel", channels)) - if chains: - conditions.append(query_field("chain", chains)) - if start_date: - conditions.append(MessageModel.time >= start_date) - if end_date: - conditions.append(MessageModel.time <= end_date) - - if conditions: - query = query.where(*conditions) - return query - - -class MessageCache(AlephClientBase): +class MessageCache(BaseAlephClient): """ A wrapper around a sqlite3 database for caching AlephMessage objects. @@ -222,12 +42,16 @@ class MessageCache(AlephClientBase): """ _instance_count = 0 # Class-level counter for active instances + missing_posts: Dict[ItemHash, PostMessage] = {} + """A dict of all posts by item_hash and their amend messages that are missing from the cache.""" def __init__(self): if db.is_closed(): db.connect() if not MessageModel.table_exists(): db.create_tables([MessageModel]) + if not PostModel.table_exists(): + db.create_tables([PostModel]) MessageCache._instance_count += 1 @@ -270,17 +94,57 @@ def __repr__(self) -> str: def __str__(self) -> str: return repr(self) - @staticmethod - def add(messages: Union[AlephMessage, Iterable[AlephMessage]]): + def add(self, messages: Union[AlephMessage, Iterable[AlephMessage]]): if isinstance(messages, typing.get_args(AlephMessage)): messages = [messages] - data_source = (message_to_model(message) for message in messages) - MessageModel.insert_many(data_source).on_conflict_replace().execute() + message_data = (message_to_model(message) for message in messages) + MessageModel.insert_many(message_data).on_conflict_replace().execute() + + # Add posts and their amends to the PostModel + post_data = [] + amend_messages = [] + for message in messages: + if message.item_type != MessageType.post: + continue + if message.content.type == "amend": + amend_messages.append(message) + else: + post = message_to_post(message).dict() + post_data.append(post) + # Check if we can now add any amend messages that had missing refs + if message.item_hash in self.missing_posts: + amend_messages += self.missing_posts.pop(message.item_hash) + + PostModel.insert_many(post_data).on_conflict_replace().execute() + + # Handle amends in second step to avoid missing original posts + post_data = [] + for message in amend_messages: + # Find the original post and update it + original_post = MessageModel.get( + MessageModel.item_hash == message.content.ref + ) + if not original_post: + latest_amend = self.missing_posts.get(ItemHash(message.content.ref)) + if latest_amend and message.time < latest_amend.time: + self.missing_posts[ItemHash(message.content.ref)] = message + continue + if datetime.fromtimestamp(message.time) < original_post.last_updated: + continue + original_post.item_hash = message.item_hash + original_post.content = message.content.content + original_post.original_item_hash = message.content.ref + original_post.original_type = message.content.type + original_post.address = message.sender + original_post.channel = message.channel + original_post.last_updated = datetime.fromtimestamp(message.time) + post_data.append(original_post) + + PostModel.insert_many(post_data).on_conflict_replace().execute() - @staticmethod def get( - item_hashes: Union[Union[ItemHash, str], Iterable[Union[ItemHash, str]]] + self, item_hashes: Union[Union[ItemHash, str], Iterable[Union[ItemHash, str]]] ) -> List[AlephMessage]: """ Get many messages from the cache by their item hash. 
@@ -347,12 +211,11 @@ async def get_posts( chains: Optional[Iterable[str]] = None, start_date: Optional[Union[datetime, float]] = None, end_date: Optional[Union[datetime, float]] = None, - ignore_invalid_messages: bool = True, - invalid_messages_log_level: int = logging.NOTSET, + ignore_invalid_messages: Optional[bool] = True, + invalid_messages_log_level: Optional[int] = logging.NOTSET, ) -> PostsResponse: - query = get_message_query( - message_type=MessageType.post, - content_types=types, + query = get_post_query( + types=types, refs=refs, addresses=addresses, tags=tags, @@ -365,7 +228,7 @@ async def get_posts( query = query.paginate(page, pagination) - posts = [model_to_message(item) for item in list(query)] + posts = [model_to_post(item) for item in list(query)] return PostsResponse( posts=posts, @@ -383,6 +246,7 @@ async def get_messages( pagination: int = 200, page: int = 1, message_type: Optional[MessageType] = None, + message_types: Optional[Iterable[MessageType]] = None, content_types: Optional[Iterable[str]] = None, content_keys: Optional[Iterable[str]] = None, refs: Optional[Iterable[str]] = None, @@ -393,14 +257,15 @@ async def get_messages( chains: Optional[Iterable[str]] = None, start_date: Optional[Union[datetime, float]] = None, end_date: Optional[Union[datetime, float]] = None, - ignore_invalid_messages: bool = True, - invalid_messages_log_level: int = logging.NOTSET, + ignore_invalid_messages: Optional[bool] = True, + invalid_messages_log_level: Optional[int] = logging.NOTSET, ) -> MessagesResponse: """ Get many messages from the cache. """ + message_types = message_types or [message_type] if message_type else None query = get_message_query( - message_type=message_type, + message_types=message_types, content_keys=content_keys, content_types=content_types, refs=refs, @@ -451,6 +316,7 @@ async def get_message( async def watch_messages( self, message_type: Optional[MessageType] = None, + message_types: Optional[Iterable[MessageType]] = None, content_types: Optional[Iterable[str]] = None, content_keys: Optional[Iterable[str]] = None, refs: Optional[Iterable[str]] = None, @@ -465,8 +331,9 @@ async def watch_messages( """ Watch messages from the cache. """ + message_types = message_types or [message_type] if message_type else None query = get_message_query( - message_type=message_type, + message_types=message_types, content_keys=content_keys, content_types=content_types, refs=refs, @@ -483,7 +350,7 @@ async def watch_messages( yield model_to_message(item) -class DomainNode(MessageCache, AuthenticatedAlephClientBase): +class DomainNode(MessageCache, BaseAuthenticatedAlephClient): """ A Domain Node is a queryable proxy for Aleph Messages that are stored in a database cache and/or in the Aleph network. 
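Usage sketch (not applied by any patch above): the snippet below illustrates how the MessageCache and DomainNode introduced in this series are expected to be driven. It is a minimal sketch, assuming the constructor signatures shown in the diffs, the `get_fallback_account` helper already used in the tests, and a reachable API server — the public `https://api2.aleph.im` endpoint and the "TEST" channel are assumptions, not part of the patch. Because DomainNode runs its initial synchronize() via run_until_complete in __init__, it must be constructed while no event loop is running.

import asyncio

from aleph.sdk import AuthenticatedAlephClient
from aleph.sdk.chains.ethereum import get_fallback_account
from aleph.sdk.node import DomainNode


async def publish_and_query(node: DomainNode) -> None:
    # Writes go through the wrapped session; accepted messages are also
    # inserted into the local cache, so they are immediately queryable.
    message, status = await node.create_post(
        post_content={"hello": "world"},
        post_type="test",
        channel="TEST",
    )
    assert message.item_hash in node  # MessageCache.__contains__

    # Reads are served from the local SQLite cache, no network round-trip.
    response = await node.get_messages(channels=["TEST"], pagination=20)
    print(f"{response.pagination_total} cached messages")


loop = asyncio.get_event_loop()
session = AuthenticatedAlephClient(
    account=get_fallback_account(),
    api_server="https://api2.aleph.im",  # assumed public endpoint
)
# Constructed synchronously: __init__ schedules the network watcher task and
# runs the initial synchronize() to completion on the current loop.
node = DomainNode(session=session, channels=["TEST"])
loop.run_until_complete(publish_and_query(node))

The same cache can later be opened offline through MessageCache() alone, which is the split the DomainNode/MessageCache layering is designed for.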
diff --git a/src/aleph/sdk/node/common.py b/src/aleph/sdk/node/common.py new file mode 100644 index 00000000..baed8b39 --- /dev/null +++ b/src/aleph/sdk/node/common.py @@ -0,0 +1,44 @@ +import json +from functools import partial +from typing import Generic, Optional, TypeVar + +from peewee import SqliteDatabase +from playhouse.sqlite_ext import JSONField +from pydantic import BaseModel + +from aleph.sdk.conf import settings + +db = SqliteDatabase(settings.CACHE_DATABASE_PATH) +T = TypeVar("T", bound=BaseModel) + + +class JSONDictEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, BaseModel): + return obj.dict() + return json.JSONEncoder.default(self, obj) + + +pydantic_json_dumps = partial(json.dumps, cls=JSONDictEncoder) + + +class PydanticField(JSONField, Generic[T]): + """ + A field for storing pydantic model types as JSON in a database. Uses json for serialization. + """ + + type: T + + def __init__(self, *args, **kwargs): + self.type = kwargs.pop("type") + super().__init__(*args, **kwargs) + + def db_value(self, value: Optional[T]) -> Optional[str]: + if value is None: + return None + return value.json() + + def python_value(self, value: Optional[str]) -> Optional[T]: + if value is None: + return None + return self.type.parse_raw(value) diff --git a/src/aleph/sdk/node/message.py b/src/aleph/sdk/node/message.py new file mode 100644 index 00000000..a3327d2a --- /dev/null +++ b/src/aleph/sdk/node/message.py @@ -0,0 +1,137 @@ +from datetime import datetime +from typing import Any, Dict, Iterable, Optional, Union + +from aleph_message import parse_message +from aleph_message.models import AlephMessage, MessageConfirmation, MessageType +from peewee import BooleanField, CharField, FloatField, IntegerField, Model +from playhouse.shortcuts import model_to_dict +from playhouse.sqlite_ext import JSONField + +from aleph.sdk.node.common import PydanticField, db, pydantic_json_dumps + + +class MessageModel(Model): + """ + A simple database model for storing AlephMessage objects. 
+ """ + + item_hash = CharField(primary_key=True) + chain = CharField(5) + type = CharField(9) + sender = CharField() + channel = CharField(null=True) + confirmations: PydanticField[MessageConfirmation] = PydanticField( + type=MessageConfirmation, null=True + ) + confirmed = BooleanField(null=True) + signature = CharField(null=True) + size = IntegerField(null=True) + time = FloatField() + item_type = CharField(7) + item_content = CharField(null=True) + hash_type = CharField(6, null=True) + content = JSONField(json_dumps=pydantic_json_dumps) + forgotten_by = CharField(null=True) + tags = JSONField(json_dumps=pydantic_json_dumps, null=True) + key = CharField(null=True) + ref = CharField(null=True) + content_type = CharField(null=True) + + class Meta: + database = db + + +def message_to_model(message: AlephMessage) -> Dict: + return { + "item_hash": str(message.item_hash), + "chain": message.chain, + "type": message.type, + "sender": message.sender, + "channel": message.channel, + "confirmations": message.confirmations[0] if message.confirmations else None, + "confirmed": message.confirmed, + "signature": message.signature, + "size": message.size, + "time": message.time, + "item_type": message.item_type, + "item_content": message.item_content, + "hash_type": message.hash_type, + "content": message.content, + "forgotten_by": message.forgotten_by[0] if message.forgotten_by else None, + "tags": message.content.content.get("tags", None) + if hasattr(message.content, "content") + else None, + "key": message.content.key if hasattr(message.content, "key") else None, + "ref": message.content.ref if hasattr(message.content, "ref") else None, + "content_type": message.content.type + if hasattr(message.content, "type") + else None, + } + + +def model_to_message(item: Any) -> AlephMessage: + item.confirmations = [item.confirmations] if item.confirmations else [] + item.forgotten_by = [item.forgotten_by] if item.forgotten_by else None + + to_exclude = [ + MessageModel.tags, + MessageModel.ref, + MessageModel.key, + MessageModel.content_type, + ] + + item_dict = model_to_dict(item, exclude=to_exclude) + return parse_message(item_dict) + + +def query_field(field_name, field_values: Iterable[str]): + field = getattr(MessageModel, field_name) + values = list(field_values) + + if len(values) == 1: + return field == values[0] + return field.in_(values) + + +def get_message_query( + message_types: Optional[Iterable[MessageType]] = None, + content_keys: Optional[Iterable[str]] = None, + content_types: Optional[Iterable[str]] = None, + refs: Optional[Iterable[str]] = None, + addresses: Optional[Iterable[str]] = None, + tags: Optional[Iterable[str]] = None, + hashes: Optional[Iterable[str]] = None, + channels: Optional[Iterable[str]] = None, + chains: Optional[Iterable[str]] = None, + start_date: Optional[Union[datetime, float]] = None, + end_date: Optional[Union[datetime, float]] = None, +): + query = MessageModel.select().order_by(MessageModel.time.desc()) + conditions = [] + if message_types: + conditions.append(query_field("type", [type.value for type in message_types])) + if content_keys: + conditions.append(query_field("key", content_keys)) + if content_types: + conditions.append(query_field("content_type", content_types)) + if refs: + conditions.append(query_field("ref", refs)) + if addresses: + conditions.append(query_field("sender", addresses)) + if tags: + for tag in tags: + conditions.append(MessageModel.tags.contains(tag)) + if hashes: + conditions.append(query_field("item_hash", hashes)) + if channels: 
+ conditions.append(query_field("channel", channels)) + if chains: + conditions.append(query_field("chain", chains)) + if start_date: + conditions.append(MessageModel.time >= start_date) + if end_date: + conditions.append(MessageModel.time <= end_date) + + if conditions: + query = query.where(*conditions) + return query diff --git a/src/aleph/sdk/node/post.py b/src/aleph/sdk/node/post.py new file mode 100644 index 00000000..b68a421d --- /dev/null +++ b/src/aleph/sdk/node/post.py @@ -0,0 +1,115 @@ +from datetime import datetime +from typing import Any, Dict, Iterable, Optional, Union + +from aleph_message.models import PostMessage +from peewee import CharField, DateTimeField, Model +from playhouse.shortcuts import model_to_dict +from playhouse.sqlite_ext import JSONField + +from aleph.sdk.models import Post +from aleph.sdk.node.common import db, pydantic_json_dumps + + +class PostModel(Model): + """ + A simple database model for storing AlephMessage objects. + """ + + original_item_hash = CharField(primary_key=True) + item_hash = CharField() + content = JSONField(json_dumps=pydantic_json_dumps) + original_type = CharField() + address = CharField() + ref = CharField(null=True) + channel = CharField(null=True) + created = DateTimeField() + last_updated = DateTimeField() + tags = JSONField(json_dumps=pydantic_json_dumps, null=True) + chain = CharField(5) + + class Meta: + database = db + + +def post_to_model(post: Post) -> Dict: + return { + "item_hash": str(post.item_hash), + "content": post.content, + "original_item_hash": str(post.original_item_hash), + "original_type": post.original_type, + "address": post.address, + "ref": post.ref, + "channel": post.channel, + "created": post.created, + "last_updated": post.last_updated, + } + + +def message_to_post(message: PostMessage) -> Post: + return Post.parse_obj( + { + "item_hash": str(message.item_hash), + "content": message.content, + "original_item_hash": str(message.item_hash), + "original_type": message.content.type + if hasattr(message.content, "type") + else None, + "address": message.sender, + "ref": message.content.ref if hasattr(message.content, "ref") else None, + "channel": message.channel, + "created": datetime.fromtimestamp(message.time), + "last_updated": datetime.fromtimestamp(message.time), + } + ) + + +def model_to_post(item: Any) -> Post: + to_exclude = [PostModel.tags, PostModel.chain] + return Post.parse_obj(model_to_dict(item, exclude=to_exclude)) + + +def query_field(field_name, field_values: Iterable[str]): + field = getattr(PostModel, field_name) + values = list(field_values) + + if len(values) == 1: + return field == values[0] + return field.in_(values) + + +def get_post_query( + types: Optional[Iterable[str]] = None, + refs: Optional[Iterable[str]] = None, + addresses: Optional[Iterable[str]] = None, + tags: Optional[Iterable[str]] = None, + hashes: Optional[Iterable[str]] = None, + channels: Optional[Iterable[str]] = None, + chains: Optional[Iterable[str]] = None, + start_date: Optional[Union[datetime, float]] = None, + end_date: Optional[Union[datetime, float]] = None, +): + query = PostModel.select().order_by(PostModel.created.desc()) + conditions = [] + if types: + conditions.append(query_field("original_type", types)) + if refs: + conditions.append(query_field("ref", refs)) + if addresses: + conditions.append(query_field("address", addresses)) + if tags: + for tag in tags: + conditions.append(PostModel.tags.contains(tag)) + if hashes: + conditions.append(query_field("item_hash", hashes)) + if channels: + 
conditions.append(query_field("channel", channels))
+    if chains:
+        conditions.append(query_field("chain", chains))
+    if start_date:
+        conditions.append(PostModel.created >= start_date)
+    if end_date:
+        conditions.append(PostModel.created <= end_date)
+
+    if conditions:
+        query = query.where(*conditions)
+    return query

From f4e71f13e5dadfadd42aa401f437763c6e381390 Mon Sep 17 00:00:00 2001
From: mhh
Date: Wed, 6 Sep 2023 21:23:40 +0200
Subject: [PATCH 06/17] fix post problems

---
 src/aleph/sdk/models.py        |  5 ++++-
 src/aleph/sdk/node/__init__.py | 18 +++++++++++-------
 src/aleph/sdk/node/post.py     |  2 +-
 tests/unit/test_node_get.py    |  9 +++++----
 4 files changed, 21 insertions(+), 13 deletions(-)

diff --git a/src/aleph/sdk/models.py b/src/aleph/sdk/models.py
index f8cdec9d..3de92c74 100644
--- a/src/aleph/sdk/models.py
+++ b/src/aleph/sdk/models.py
@@ -37,12 +37,15 @@ class Post(BaseModel):
     )
     address: str = Field(description="The address of the sender of the POST message")
     ref: Optional[str] = Field(description="Other message referenced by this one")
-    channel: str = Field(description="The channel where the POST message was published")
+    channel: Optional[str] = Field(description="The channel where the POST message was published")
     created: datetime = Field(description="The time when the POST message was created")
     last_updated: datetime = Field(
         description="The time when the POST message was last updated"
     )
 
+    class Config:
+        extra = "forbid"
+
 
 class PostsResponse(PaginationResponse):
     """Response from an Aleph node API on the path /api/v0/posts.json"""
diff --git a/src/aleph/sdk/node/__init__.py b/src/aleph/sdk/node/__init__.py
index 1477ac2c..f95256cc 100644
--- a/src/aleph/sdk/node/__init__.py
+++ b/src/aleph/sdk/node/__init__.py
@@ -98,6 +98,8 @@ def add(self, messages: Union[AlephMessage, Iterable[AlephMessage]]):
         if isinstance(messages, typing.get_args(AlephMessage)):
             messages = [messages]
 
+        messages = list(messages)
+
         message_data = (message_to_model(message) for message in messages)
         MessageModel.insert_many(message_data).on_conflict_replace().execute()
 
@@ -105,16 +107,18 @@ def add(self, messages: Union[AlephMessage, Iterable[AlephMessage]]):
         post_data = []
         amend_messages = []
         for message in messages:
-            if message.item_type != MessageType.post:
+            if message.type != MessageType.post.value:
                 continue
             if message.content.type == "amend":
                 amend_messages.append(message)
-            else:
-                post = message_to_post(message).dict()
-                post_data.append(post)
-                # Check if we can now add any amend messages that had missing refs
-                if message.item_hash in self.missing_posts:
-                    amend_messages += self.missing_posts.pop(message.item_hash)
+                continue
+            post = message_to_post(message).dict()
+            post["chain"] = message.chain.value
+            post["tags"] = message.content.content.get("tags", None)
+            post_data.append(post)
+            # Check if we can now add any amend messages that had missing refs
+            if message.item_hash in self.missing_posts:
+                amend_messages += self.missing_posts.pop(message.item_hash)
 
         PostModel.insert_many(post_data).on_conflict_replace().execute()
 
diff --git a/src/aleph/sdk/node/post.py b/src/aleph/sdk/node/post.py
index b68a421d..e4af0807 100644
--- a/src/aleph/sdk/node/post.py
+++ b/src/aleph/sdk/node/post.py
@@ -58,7 +58,7 @@ def message_to_post(message: PostMessage) -> Post:
         "ref": message.content.ref if hasattr(message.content, "ref") else None,
         "channel": message.channel,
         "created": datetime.fromtimestamp(message.time),
-        "last_updated": datetime.fromtimestamp(message.time),
+        "last_updated": 
datetime.fromtimestamp(message.time) } ) diff --git a/tests/unit/test_node_get.py b/tests/unit/test_node_get.py index 48bff3b8..5f80d2c6 100644 --- a/tests/unit/test_node_get.py +++ b/tests/unit/test_node_get.py @@ -13,7 +13,7 @@ from aleph.sdk.chains.ethereum import get_fallback_account from aleph.sdk.exceptions import MessageNotFoundError -from aleph.sdk.node import MessageCache +from aleph.sdk.node import MessageCache, message_to_post @pytest.mark.asyncio @@ -137,7 +137,7 @@ def class_teardown(self): @pytest.mark.asyncio async def test_addresses(self): items = (await self.cache.get_posts(addresses=[self.messages[1].sender])).posts - assert items[0] == self.messages[1] + assert items[0] == message_to_post(self.messages[1]) @pytest.mark.asyncio async def test_tags(self): @@ -153,15 +153,16 @@ async def test_types(self): @pytest.mark.asyncio async def test_channels(self): + print(self.messages[1]) assert (await self.cache.get_posts(channels=[self.messages[1].channel])).posts[ 0 - ] == self.messages[1] + ] == message_to_post(self.messages[1]) @pytest.mark.asyncio async def test_chains(self): assert (await self.cache.get_posts(chains=[self.messages[1].chain])).posts[ 0 - ] == self.messages[1] + ] == message_to_post(self.messages[1]) @pytest.mark.asyncio From f5261adf5560883f7225ae67b4159db7e0f3be84 Mon Sep 17 00:00:00 2001 From: mhh Date: Wed, 6 Sep 2023 21:44:18 +0200 Subject: [PATCH 07/17] fix testing problem; reformat --- src/aleph/sdk/models.py | 4 +++- src/aleph/sdk/node/post.py | 2 +- tests/unit/conftest.py | 10 +++++----- tests/unit/test_node.py | 12 ++++++++---- tests/unit/test_node_get.py | 3 ++- 5 files changed, 19 insertions(+), 12 deletions(-) diff --git a/src/aleph/sdk/models.py b/src/aleph/sdk/models.py index 3de92c74..40083bb2 100644 --- a/src/aleph/sdk/models.py +++ b/src/aleph/sdk/models.py @@ -37,7 +37,9 @@ class Post(BaseModel): ) address: str = Field(description="The address of the sender of the POST message") ref: Optional[str] = Field(description="Other message referenced by this one") - channel: Optional[str] = Field(description="The channel where the POST message was published") + channel: Optional[str] = Field( + description="The channel where the POST message was published" + ) created: datetime = Field(description="The time when the POST message was created") last_updated: datetime = Field( description="The time when the POST message was last updated" diff --git a/src/aleph/sdk/node/post.py b/src/aleph/sdk/node/post.py index e4af0807..b68a421d 100644 --- a/src/aleph/sdk/node/post.py +++ b/src/aleph/sdk/node/post.py @@ -58,7 +58,7 @@ def message_to_post(message: PostMessage) -> Post: "ref": message.content.ref if hasattr(message.content, "ref") else None, "channel": message.channel, "created": datetime.fromtimestamp(message.time), - "last_updated": datetime.fromtimestamp(message.time) + "last_updated": datetime.fromtimestamp(message.time), } ) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index cd071744..5d6d2388 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -113,10 +113,10 @@ def messages() -> List[AlephMessage]: @pytest.fixture def raw_messages_response(messages): - return { - "messages": [message.dict() for message in messages], + return lambda page: { + "messages": [message.dict() for message in messages] if page == 1 else [], "pagination_item": "messages", - "pagination_page": 1, - "pagination_per_page": 20, - "pagination_total": 2, + "pagination_page": page, + "pagination_per_page": max(len(messages), 20), + 
"pagination_total": len(messages) if page == 1 else 0, } diff --git a/tests/unit/test_node.py b/tests/unit/test_node.py index 0b844e50..6ea9d44c 100644 --- a/tests/unit/test_node.py +++ b/tests/unit/test_node.py @@ -54,8 +54,9 @@ async def text(self): class MockGetResponse: - def __init__(self, response): - self.response = response + def __init__(self, response_message, page=1): + self.response_message = response_message + self.page = page async def __aenter__(self): return self @@ -72,7 +73,7 @@ def raise_for_status(self): raise Exception("Bad status code") async def json(self): - return self.response + return self.response_message(self.page) @pytest.fixture @@ -92,7 +93,10 @@ def mock_session_with_two_messages( sync=kwargs.get("sync", False), ) http_session.get = MagicMock() - http_session.get.return_value = MockGetResponse(raw_messages_response) + http_session.get.side_effect = lambda *args, **kwargs: MockGetResponse( + response_message=raw_messages_response, + page=kwargs.get("params", {}).get("page", 1), + ) client = AuthenticatedAlephClient( account=ethereum_account, api_server="http://localhost" diff --git a/tests/unit/test_node_get.py b/tests/unit/test_node_get.py index 5f80d2c6..2326803a 100644 --- a/tests/unit/test_node_get.py +++ b/tests/unit/test_node_get.py @@ -13,7 +13,8 @@ from aleph.sdk.chains.ethereum import get_fallback_account from aleph.sdk.exceptions import MessageNotFoundError -from aleph.sdk.node import MessageCache, message_to_post +from aleph.sdk.node import MessageCache +from aleph.sdk.node.post import message_to_post @pytest.mark.asyncio From 7c6085a6fe05c3dc652eaecbc7af131a36cc4fdb Mon Sep 17 00:00:00 2001 From: mhh Date: Wed, 6 Sep 2023 21:46:36 +0200 Subject: [PATCH 08/17] remove problematic integration test that cannot use v1 posts --- tests/integration/itest_forget.py | 28 +--------------------------- 1 file changed, 1 insertion(+), 27 deletions(-) diff --git a/tests/integration/itest_forget.py b/tests/integration/itest_forget.py index 29b6c6d9..cf780ed7 100644 --- a/tests/integration/itest_forget.py +++ b/tests/integration/itest_forget.py @@ -100,31 +100,5 @@ async def test_forget_a_forget_message(fixture_account): """ Attempts to forget a forget message. This should fail. """ - # TODO: this test should be moved to the PyAleph API tests, once a framework is in place. - post_hash = await create_and_forget_post(fixture_account, TARGET_NODE, TARGET_NODE) - async with AuthenticatedAlephClient( - account=fixture_account, api_server=TARGET_NODE - ) as session: - get_post_response = await session.get_posts(hashes=[post_hash]) - assert len(get_post_response.posts) == 1 - post = get_post_response.posts[0] - - forget_message_hash = post.forgotten_by[0] - forget_message, forget_status = await session.forget( - hashes=[forget_message_hash], - reason="I want to remember this post. 
Maybe I can forget I forgot it?", - channel=TEST_CHANNEL, - ) - - print(forget_message) - - get_forget_message_response = await session.get_messages( - hashes=[forget_message_hash], - channels=[TEST_CHANNEL], - ) - assert len(get_forget_message_response.messages) == 1 - forget_message = get_forget_message_response.messages[0] - print(forget_message) - - assert "forgotten_by" not in forget_message + pass From 9ac0fe6a5b4a86becf496aaa179e09bd5f35cc15 Mon Sep 17 00:00:00 2001 From: mhh Date: Wed, 27 Sep 2023 18:43:42 +0200 Subject: [PATCH 09/17] rename messages fixture to aleph_messages --- tests/unit/conftest.py | 10 ++++---- tests/unit/test_node_get.py | 46 ++++++++++++++++++------------------- 2 files changed, 28 insertions(+), 28 deletions(-) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 5d6d2388..9a2a41c7 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -55,7 +55,7 @@ def json_messages(): @pytest.fixture -def messages() -> List[AlephMessage]: +def aleph_messages() -> List[AlephMessage]: return [ AggregateMessage.parse_obj( { @@ -112,11 +112,11 @@ def messages() -> List[AlephMessage]: @pytest.fixture -def raw_messages_response(messages): +def raw_messages_response(aleph_messages): return lambda page: { - "messages": [message.dict() for message in messages] if page == 1 else [], + "messages": [message.dict() for message in aleph_messages] if page == 1 else [], "pagination_item": "messages", "pagination_page": page, - "pagination_per_page": max(len(messages), 20), - "pagination_total": len(messages) if page == 1 else 0, + "pagination_per_page": max(len(aleph_messages), 20), + "pagination_total": len(aleph_messages) if page == 1 else 0, } diff --git a/tests/unit/test_node_get.py b/tests/unit/test_node_get.py index 2326803a..8052203c 100644 --- a/tests/unit/test_node_get.py +++ b/tests/unit/test_node_get.py @@ -18,20 +18,20 @@ @pytest.mark.asyncio -async def test_base(messages): +async def test_base(aleph_messages): # test add_many cache = MessageCache() - cache.add(messages) - assert len(cache) == len(messages) + cache.add(aleph_messages) + assert len(cache) == len(aleph_messages) - item_hashes = [message.item_hash for message in messages] + item_hashes = [message.item_hash for message in aleph_messages] cached_messages = cache.get(item_hashes) - assert len(cached_messages) == len(messages) + assert len(cached_messages) == len(aleph_messages) - for message in messages: + for message in aleph_messages: assert cache[message.item_hash] == message - for message in messages: + for message in aleph_messages: assert message.item_hash in cache for message in cache: @@ -47,8 +47,8 @@ class TestMessageQueries: cache: MessageCache @pytest.fixture(autouse=True) - def class_setup(self, messages): - self.messages = messages + def class_setup(self, aleph_messages): + self.messages = aleph_messages self.cache = MessageCache() self.cache.add(self.messages) @@ -127,8 +127,8 @@ class TestPostQueries: cache: MessageCache @pytest.fixture(autouse=True) - def class_setup(self, messages): - self.messages = messages + def class_setup(self, aleph_messages): + self.messages = aleph_messages self.cache = MessageCache() self.cache.add(self.messages) @@ -196,33 +196,33 @@ async def mock_message_stream(): @pytest.mark.asyncio -async def test_fetch_aggregate(messages): +async def test_fetch_aggregate(aleph_messages): cache = MessageCache() - cache.add(messages) + cache.add(aleph_messages) - aggregate = await cache.fetch_aggregate(messages[0].sender, messages[0].content.key) + 
aggregate = await cache.fetch_aggregate(aleph_messages[0].sender, aleph_messages[0].content.key) - assert aggregate == messages[0].content.content + assert aggregate == aleph_messages[0].content.content @pytest.mark.asyncio -async def test_fetch_aggregates(messages): +async def test_fetch_aggregates(aleph_messages): cache = MessageCache() - cache.add(messages) + cache.add(aleph_messages) - aggregates = await cache.fetch_aggregates(messages[0].sender) + aggregates = await cache.fetch_aggregates(aleph_messages[0].sender) - assert aggregates == {messages[0].content.key: messages[0].content.content} + assert aggregates == {aleph_messages[0].content.key: aleph_messages[0].content.content} @pytest.mark.asyncio -async def test_get_message(messages): +async def test_get_message(aleph_messages): cache = MessageCache() - cache.add(messages) + cache.add(aleph_messages) - message: AlephMessage = await cache.get_message(messages[0].item_hash) + message: AlephMessage = await cache.get_message(aleph_messages[0].item_hash) - assert message == messages[0] + assert message == aleph_messages[0] @pytest.mark.asyncio From a99fa66acf174d489440ab65d5cf44bba4c1bbe5 Mon Sep 17 00:00:00 2001 From: mhh Date: Wed, 27 Sep 2023 19:04:09 +0200 Subject: [PATCH 10/17] reformat black --- tests/unit/test_node_get.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_node_get.py b/tests/unit/test_node_get.py index 8052203c..7c7c7678 100644 --- a/tests/unit/test_node_get.py +++ b/tests/unit/test_node_get.py @@ -200,7 +200,9 @@ async def test_fetch_aggregate(aleph_messages): cache = MessageCache() cache.add(aleph_messages) - aggregate = await cache.fetch_aggregate(aleph_messages[0].sender, aleph_messages[0].content.key) + aggregate = await cache.fetch_aggregate( + aleph_messages[0].sender, aleph_messages[0].content.key + ) assert aggregate == aleph_messages[0].content.content @@ -212,7 +214,9 @@ async def test_fetch_aggregates(aleph_messages): aggregates = await cache.fetch_aggregates(aleph_messages[0].sender) - assert aggregates == {aleph_messages[0].content.key: aleph_messages[0].content.content} + assert aggregates == { + aleph_messages[0].content.key: aleph_messages[0].content.content + } @pytest.mark.asyncio From d8162eca59316e1db994c1a43d569420b2442147 Mon Sep 17 00:00:00 2001 From: mhh Date: Wed, 27 Sep 2023 19:07:50 +0200 Subject: [PATCH 11/17] update .gitignore --- .gitignore | 1 + cache/files/QmAndSoOn | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) delete mode 100644 cache/files/QmAndSoOn diff --git a/.gitignore b/.gitignore index c4734889..a12a6219 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ *.pot __pycache__/* .cache/* +cache/**/* .*.swp */.ipynb_checkpoints/* diff --git a/cache/files/QmAndSoOn b/cache/files/QmAndSoOn deleted file mode 100644 index d9605cba..00000000 --- a/cache/files/QmAndSoOn +++ /dev/null @@ -1 +0,0 @@ -HELLO \ No newline at end of file From e0f5944d5d3a2228c2a00c2129de3a19a2088c12 Mon Sep 17 00:00:00 2001 From: mhh Date: Wed, 27 Sep 2023 19:26:50 +0200 Subject: [PATCH 12/17] fix return type --- src/aleph/sdk/client.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/aleph/sdk/client.py b/src/aleph/sdk/client.py index 4f898b72..0803b29a 100644 --- a/src/aleph/sdk/client.py +++ b/src/aleph/sdk/client.py @@ -246,7 +246,7 @@ def download_file_ipfs(self, file_hash: str) -> bytes: def download_file_to_buffer( self, file_hash: str, output_buffer: Writable[bytes] - ) -> bytes: + ) -> None: return self._wrap( 
self.async_session.download_file_to_buffer,
             file_hash=file_hash,
@@ -255,7 +255,7 @@ def download_file_to_buffer(
 
     def download_file_ipfs_to_buffer(
         self, file_hash: str, output_buffer: Writable[bytes]
-    ) -> bytes:
+    ) -> None:
         return self._wrap(
             self.async_session.download_file_ipfs_to_buffer,
             file_hash=file_hash,
@@ -677,7 +677,7 @@ async def download_file_ipfs_to_buffer(
         :param file_hash: The hash of the file to retrieve.
         :param output_buffer: The binary output buffer to write the file data to.
         """
-        async with aiohttp.ClientSession() as session:
+        async with self.http_session as session:
             async with session.get(
                 f"https://ipfs.aleph.im/ipfs/{file_hash}"
             ) as response:

From 4386802fe6df440a731cfefe9ea692398405f15d Mon Sep 17 00:00:00 2001
From: mhh
Date: Wed, 27 Sep 2023 19:39:40 +0200
Subject: [PATCH 13/17] use new message_types parameter for DomainNode
 initialization

---
 src/aleph/sdk/base.py          |  6 +++---
 src/aleph/sdk/node/__init__.py | 18 +++++++++---------
 2 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/src/aleph/sdk/base.py b/src/aleph/sdk/base.py
index a5b2c266..db9d4a02 100644
--- a/src/aleph/sdk/base.py
+++ b/src/aleph/sdk/base.py
@@ -204,7 +204,7 @@ async def get_messages(
 
     async def get_messages_iterator(
         self,
-        message_type: Optional[MessageType] = None,
+        message_types: Optional[Iterable[MessageType]] = None,
         content_types: Optional[Iterable[str]] = None,
         content_keys: Optional[Iterable[str]] = None,
         refs: Optional[Iterable[str]] = None,
@@ -220,7 +220,7 @@ async def get_messages_iterator(
         Fetch all filtered messages, returning an async iterator and fetching them page by page.
         Might return duplicates but will always return all messages.
 
-        :param message_type: Filter by message type, can be "AGGREGATE", "POST", "PROGRAM", "VM", "STORE" or "FORGET"
+        :param message_types: Filter by message types, can be any combination of "AGGREGATE", "POST", "PROGRAM", "VM", "STORE" or "FORGET"
         :param content_types: Filter by content type
         :param content_keys: Filter by content key
         :param refs: If set, only fetch posts that reference these hashes (in the "refs" field)
@@ -237,7 +237,7 @@ async def get_messages_iterator(
         while resp is None or len(resp.messages) > 0:
             resp = await self.get_messages(
                 page=page,
-                message_type=message_type,
+                message_types=message_types,
                 content_types=content_types,
                 content_keys=content_keys,
                 refs=refs,
diff --git a/src/aleph/sdk/node/__init__.py b/src/aleph/sdk/node/__init__.py
index f95256cc..d95fd01d 100644
--- a/src/aleph/sdk/node/__init__.py
+++ b/src/aleph/sdk/node/__init__.py
@@ -359,9 +359,9 @@ class DomainNode(MessageCache, BaseAuthenticatedAlephClient):
     A Domain Node is a queryable proxy for Aleph Messages that are stored in a database cache and/or
     in the Aleph network.
 
-    It synchronizes with the network on a subset of the messages by listening to the network and storing the
-    messages in the cache. The user may define the subset by specifying a channels, tags, senders, chains,
-    message types, and/or a time window.
+    It synchronizes with the network on a subset of the messages (the "domain") by listening to the network and storing the
+    messages in the cache. The user may define the domain by specifying channels, tags, senders, chains and/or
+    message types.
""" def __init__( @@ -371,7 +371,7 @@ def __init__( tags: Optional[Iterable[str]] = None, addresses: Optional[Iterable[str]] = None, chains: Optional[Iterable[Chain]] = None, - message_type: Optional[MessageType] = None, + message_types: Optional[Iterable[MessageType]] = None, ): super().__init__() self.session = session @@ -379,7 +379,7 @@ def __init__( self.tags = tags self.addresses = addresses self.chains = chains - self.message_type = message_type + self.message_types = message_types # start listening to the network and storing messages in the cache asyncio.get_event_loop().create_task( @@ -389,7 +389,7 @@ def __init__( tags=self.tags, addresses=self.addresses, chains=self.chains, - message_type=self.message_type, + message_types=self.message_types, ) ) ) @@ -401,7 +401,7 @@ def __init__( tags=self.tags, addresses=self.addresses, chains=self.chains, - message_type=self.message_type, + message_types=self.message_types, ) ) @@ -417,7 +417,7 @@ async def synchronize( tags: Optional[Iterable[str]] = None, addresses: Optional[Iterable[str]] = None, chains: Optional[Iterable[Chain]] = None, - message_type: Optional[MessageType] = None, + message_types: Optional[Iterable[MessageType]] = None, start_date: Optional[Union[datetime, float]] = None, end_date: Optional[Union[datetime, float]] = None, ): @@ -431,7 +431,7 @@ async def synchronize( tags=tags, addresses=addresses, chains=chains, - message_type=message_type, + message_types=message_types, start_date=start_date, end_date=end_date, ): From 5559a3ef1cb0987e21555982edf892e5cd42be14 Mon Sep 17 00:00:00 2001 From: mhh Date: Wed, 27 Sep 2023 20:04:58 +0200 Subject: [PATCH 14/17] revert change to download_file_ipfs_to_buffer --- src/aleph/sdk/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aleph/sdk/client.py b/src/aleph/sdk/client.py index 0803b29a..620391d8 100644 --- a/src/aleph/sdk/client.py +++ b/src/aleph/sdk/client.py @@ -677,7 +677,7 @@ async def download_file_ipfs_to_buffer( :param file_hash: The hash of the file to retrieve. :param output_buffer: The binary output buffer to write the file data to. 
""" - async with self.http_session as session: + async with aiohttp.ClientSession() as session: async with session.get( f"https://ipfs.aleph.im/ipfs/{file_hash}" ) as response: From 8d6225e8293dcbba2716a1c8547228c29a1eeb2c Mon Sep 17 00:00:00 2001 From: mhh Date: Wed, 27 Sep 2023 20:06:18 +0200 Subject: [PATCH 15/17] add restriction to posted messages not being in domain; add default account and chain of used Account to domain --- src/aleph/sdk/node/__init__.py | 45 ++++++++++++++++++++++++++++++++-- 1 file changed, 43 insertions(+), 2 deletions(-) diff --git a/src/aleph/sdk/node/__init__.py b/src/aleph/sdk/node/__init__.py index d95fd01d..b6b8f1e8 100644 --- a/src/aleph/sdk/node/__init__.py +++ b/src/aleph/sdk/node/__init__.py @@ -377,8 +377,16 @@ def __init__( self.session = session self.channels = channels self.tags = tags - self.addresses = addresses - self.chains = chains + self.addresses = ( + list(addresses) + [session.account.get_address()] + if addresses + else [session.account.get_address()] + ) + self.chains = ( + list(chains) + [Chain(session.account.CHAIN)] + if chains + else [session.account.CHAIN] + ) self.message_types = message_types # start listening to the network and storing messages in the cache @@ -460,6 +468,34 @@ async def download_file(self, file_hash: str) -> bytes: def _file_path(file_hash: str) -> Path: return settings.CACHE_FILES_PATH / Path(file_hash) + def check_validity( + self, + message_type: MessageType, + address: Optional[str] = None, + channel: Optional[str] = None, + content: Optional[Dict] = None, + ): + if self.message_types and message_type not in self.message_types: + raise ValueError( + f"Cannot create {message_type.value} message because DomainNode is not listening to post messages." + ) + if address and self.addresses and address not in self.addresses: + raise ValueError( + f"Cannot create {message_type.value} message because DomainNode is not listening to messages from address {address}." + ) + if self.channels and channel not in self.channels: + raise ValueError( + f"Cannot create {message_type.value} message because DomainNode is not listening to messages from channel {channel}." + ) + if ( + content + and self.tags + and not set(content.get("tags", [])).intersection(self.tags) + ): + raise ValueError( + f"Cannot create {message_type.value} message because DomainNode is not listening to any of these tags: {content.get('tags', [])}." 
+ ) + async def create_post( self, post_content: Any, @@ -471,6 +507,7 @@ async def create_post( storage_engine: StorageEnum = StorageEnum.storage, sync: bool = False, ) -> Tuple[AlephMessage, MessageStatus]: + self.check_validity(MessageType.post, address, channel, post_content) resp, status = await self.session.create_post( post_content=post_content, post_type=post_type, @@ -495,6 +532,7 @@ async def create_aggregate( inline: bool = True, sync: bool = False, ) -> Tuple[AlephMessage, MessageStatus]: + self.check_validity(MessageType.aggregate, address, channel) resp, status = await self.session.create_aggregate( key=key, content=content, @@ -520,6 +558,7 @@ async def create_store( channel: Optional[str] = None, sync: bool = False, ) -> Tuple[AlephMessage, MessageStatus]: + self.check_validity(MessageType.store, address, channel, extra_fields) resp, status = await self.session.create_store( address=address, file_content=file_content, @@ -555,6 +594,7 @@ async def create_program( subscriptions: Optional[List[Mapping]] = None, metadata: Optional[Mapping[str, Any]] = None, ) -> Tuple[AlephMessage, MessageStatus]: + self.check_validity(MessageType.program, address, channel, metadata) resp, status = await self.session.create_program( program_ref=program_ref, entrypoint=entrypoint, @@ -586,6 +626,7 @@ async def forget( address: Optional[str] = None, sync: bool = False, ) -> Tuple[AlephMessage, MessageStatus]: + self.check_validity(MessageType.forget, address, channel) resp, status = await self.session.forget( hashes=hashes, reason=reason, From c60c28760f5086a0c12362b22aab64abe2d5a9ec Mon Sep 17 00:00:00 2001 From: mhh Date: Wed, 27 Sep 2023 20:14:53 +0200 Subject: [PATCH 16/17] fix typing issue --- src/aleph/sdk/node/__init__.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/aleph/sdk/node/__init__.py b/src/aleph/sdk/node/__init__.py index b6b8f1e8..42ac739b 100644 --- a/src/aleph/sdk/node/__init__.py +++ b/src/aleph/sdk/node/__init__.py @@ -483,7 +483,7 @@ def check_validity( raise ValueError( f"Cannot create {message_type.value} message because DomainNode is not listening to messages from address {address}." ) - if self.channels and channel not in self.channels: + if channel and self.channels and channel not in self.channels: raise ValueError( f"Cannot create {message_type.value} message because DomainNode is not listening to messages from channel {channel}." 
) @@ -594,7 +594,9 @@ async def create_program( subscriptions: Optional[List[Mapping]] = None, metadata: Optional[Mapping[str, Any]] = None, ) -> Tuple[AlephMessage, MessageStatus]: - self.check_validity(MessageType.program, address, channel, metadata) + self.check_validity( + MessageType.program, address, channel, dict(metadata) if metadata else None + ) resp, status = await self.session.create_program( program_ref=program_ref, entrypoint=entrypoint, From cd819180a4d5e3563626c756433291a24f0ea705 Mon Sep 17 00:00:00 2001 From: mhh Date: Thu, 28 Sep 2023 19:09:50 +0200 Subject: [PATCH 17/17] Refactor/Breaking Changes: Add MessageFilter and PostFilter to client methods; refactor data model into new modules; fix tests and edge cases in DomainNode; improve mocking of HTTP/WS connection; fix caching of amend messages --- src/aleph/sdk/base.py | 144 +--------- src/aleph/sdk/client.py | 228 ++------------- src/aleph/sdk/models.py | 56 ---- src/aleph/sdk/models/__init__.py | 0 src/aleph/sdk/models/common.py | 39 +++ src/aleph/sdk/models/db/__init__.py | 0 src/aleph/sdk/{node => models/db}/common.py | 0 src/aleph/sdk/models/db/message.py | 36 +++ src/aleph/sdk/models/db/post.py | 25 ++ src/aleph/sdk/models/message.py | 190 +++++++++++++ src/aleph/sdk/models/post.py | 170 +++++++++++ src/aleph/sdk/{node/__init__.py => node.py} | 295 +++++++++----------- src/aleph/sdk/node/message.py | 137 --------- src/aleph/sdk/node/post.py | 115 -------- tests/unit/conftest.py | 10 +- tests/unit/test_asynchronous_get.py | 15 +- tests/unit/test_node.py | 81 +++++- tests/unit/test_node_get.py | 150 +++++++--- tests/unit/test_synchronous_get.py | 6 +- 19 files changed, 842 insertions(+), 855 deletions(-) delete mode 100644 src/aleph/sdk/models.py create mode 100644 src/aleph/sdk/models/__init__.py create mode 100644 src/aleph/sdk/models/common.py create mode 100644 src/aleph/sdk/models/db/__init__.py rename src/aleph/sdk/{node => models/db}/common.py (100%) create mode 100644 src/aleph/sdk/models/db/message.py create mode 100644 src/aleph/sdk/models/db/post.py create mode 100644 src/aleph/sdk/models/message.py create mode 100644 src/aleph/sdk/models/post.py rename src/aleph/sdk/{node/__init__.py => node.py} (68%) delete mode 100644 src/aleph/sdk/node/message.py delete mode 100644 src/aleph/sdk/node/post.py diff --git a/src/aleph/sdk/base.py b/src/aleph/sdk/base.py index db9d4a02..ea3ac9b3 100644 --- a/src/aleph/sdk/base.py +++ b/src/aleph/sdk/base.py @@ -2,7 +2,6 @@ import logging from abc import ABC, abstractmethod -from datetime import datetime from pathlib import Path from typing import ( Any, @@ -26,8 +25,9 @@ from aleph_message.models.execution.program import Encoding from aleph_message.status import MessageStatus -from aleph.sdk.models import PostsResponse -from aleph.sdk.types import GenericMessage, StorageEnum +from .models.message import MessageFilter +from .models.post import PostFilter, PostsResponse +from .types import GenericMessage, StorageEnum DEFAULT_PAGE_SIZE = 200 @@ -70,15 +70,7 @@ async def get_posts( self, pagination: int = DEFAULT_PAGE_SIZE, page: int = 1, - types: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + post_filter: Optional[PostFilter] = None, 
ignore_invalid_messages: Optional[bool] = True, invalid_messages_log_level: Optional[int] = logging.NOTSET, ) -> PostsResponse: @@ -87,15 +79,7 @@ async def get_posts( :param pagination: Number of items to fetch (Default: 200) :param page: Page to fetch, begins at 1 (Default: 1) - :param types: Types of posts to fetch (Default: all types) - :param refs: If set, only fetch posts that reference these hashes (in the "refs" field) - :param addresses: Addresses of the posts to fetch (Default: all addresses) - :param tags: Tags of the posts to fetch (Default: all tags) - :param hashes: Specific item_hashes to fetch - :param channels: Channels of the posts to fetch (Default: all channels) - :param chains: Chains of the posts to fetch (Default: all chains) - :param start_date: Earliest date to fetch messages from - :param end_date: Latest date to fetch messages from + :param post_filter: Filter to apply to the posts (Default: None) :param ignore_invalid_messages: Ignore invalid messages (Default: True) :param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET) """ @@ -103,44 +87,20 @@ async def get_posts( async def get_posts_iterator( self, - types: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + post_filter: Optional[PostFilter] = None, ) -> AsyncIterable[PostMessage]: """ Fetch all filtered posts, returning an async iterator and fetching them page by page. Might return duplicates but will always return all posts. 
- :param types: Types of posts to fetch (Default: all types) - :param refs: If set, only fetch posts that reference these hashes (in the "refs" field) - :param addresses: Addresses of the posts to fetch (Default: all addresses) - :param tags: Tags of the posts to fetch (Default: all tags) - :param hashes: Specific item_hashes to fetch - :param channels: Channels of the posts to fetch (Default: all channels) - :param chains: Chains of the posts to fetch (Default: all chains) - :param start_date: Earliest date to fetch messages from - :param end_date: Latest date to fetch messages from + :param post_filter: Filter to apply to the posts (Default: None) """ page = 1 resp = None while resp is None or len(resp.posts) > 0: resp = await self.get_posts( page=page, - types=types, - refs=refs, - addresses=addresses, - tags=tags, - hashes=hashes, - channels=channels, - chains=chains, - start_date=start_date, - end_date=end_date, + post_filter=post_filter, ) page += 1 for post in resp.posts: @@ -165,18 +125,7 @@ async def get_messages( self, pagination: int = DEFAULT_PAGE_SIZE, page: int = 1, - message_type: Optional[MessageType] = None, - message_types: Optional[Iterable[MessageType]] = None, - content_types: Optional[Iterable[str]] = None, - content_keys: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + message_filter: Optional[MessageFilter] = None, ignore_invalid_messages: Optional[bool] = True, invalid_messages_log_level: Optional[int] = logging.NOTSET, ) -> MessagesResponse: @@ -185,18 +134,7 @@ async def get_messages( :param pagination: Number of items to fetch (Default: 200) :param page: Page to fetch, begins at 1 (Default: 1) - :param message_type: [DEPRECATED] Filter by message type, can be "AGGREGATE", "POST", "PROGRAM", "VM", "STORE" or "FORGET" - :param message_types: Filter by message types, can be any combination of "AGGREGATE", "POST", "PROGRAM", "VM", "STORE" or "FORGET" - :param content_types: Filter by content type - :param content_keys: Filter by aggregate key - :param refs: If set, only fetch posts that reference these hashes (in the "refs" field) - :param addresses: Addresses of the posts to fetch (Default: all addresses) - :param tags: Tags of the posts to fetch (Default: all tags) - :param hashes: Specific item_hashes to fetch - :param channels: Channels of the posts to fetch (Default: all channels) - :param chains: Filter by sender address chain - :param start_date: Earliest date to fetch messages from - :param end_date: Latest date to fetch messages from + :param message_filter: Filter to apply to the messages :param ignore_invalid_messages: Ignore invalid messages (Default: True) :param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET) """ @@ -204,50 +142,20 @@ async def get_messages( async def get_messages_iterator( self, - message_types: Optional[Iterable[MessageType]] = None, - content_types: Optional[Iterable[str]] = None, - content_keys: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: 
Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + message_filter: Optional[MessageFilter] = None, ) -> AsyncIterable[AlephMessage]: """ Fetch all filtered messages, returning an async iterator and fetching them page by page. Might return duplicates but will always return all messages. - :param message_types: Filter by message type, can be "AGGREGATE", "POST", "PROGRAM", "VM", "STORE" or "FORGET" - :param content_types: Filter by content type - :param content_keys: Filter by content key - :param refs: If set, only fetch posts that reference these hashes (in the "refs" field) - :param addresses: Addresses of the posts to fetch (Default: all addresses) - :param tags: Tags of the posts to fetch (Default: all tags) - :param hashes: Specific item_hashes to fetch - :param channels: Channels of the posts to fetch (Default: all channels) - :param chains: Filter by sender address chain - :param start_date: Earliest date to fetch messages from - :param end_date: Latest date to fetch messages from + :param message_filter: Filter to apply to the messages """ page = 1 resp = None while resp is None or len(resp.messages) > 0: resp = await self.get_messages( page=page, - message_types=message_types, - content_types=content_types, - content_keys=content_keys, - refs=refs, - addresses=addresses, - tags=tags, - hashes=hashes, - channels=channels, - chains=chains, - start_date=start_date, - end_date=end_date, + message_filter=message_filter, ) page += 1 for message in resp.messages: @@ -272,34 +180,12 @@ async def get_message( @abstractmethod def watch_messages( self, - message_type: Optional[MessageType] = None, - message_types: Optional[Iterable[MessageType]] = None, - content_types: Optional[Iterable[str]] = None, - content_keys: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + message_filter: Optional[MessageFilter] = None, ) -> AsyncIterable[AlephMessage]: """ Iterate over current and future matching messages asynchronously. 
- :param message_type: [DEPRECATED] Type of message to watch - :param message_types: Types of messages to watch - :param content_types: Content types to watch - :param content_keys: Filter by aggregate key - :param refs: References to watch - :param addresses: Addresses to watch - :param tags: Tags to watch - :param hashes: Hashes to watch - :param channels: Channels to watch - :param chains: Chains to watch - :param start_date: Start date from when to watch - :param end_date: End date until when to watch + :param message_filter: Filter to apply to the messages """ pass diff --git a/src/aleph/sdk/client.py b/src/aleph/sdk/client.py index 620391d8..ac9bee80 100644 --- a/src/aleph/sdk/client.py +++ b/src/aleph/sdk/client.py @@ -5,8 +5,6 @@ import queue import threading import time -import warnings -from datetime import datetime from io import BytesIO from pathlib import Path from typing import ( @@ -61,7 +59,8 @@ MessageNotFoundError, MultipleMessagesError, ) -from .models import MessagesResponse, Post, PostsResponse +from .models.message import MessageFilter, MessagesResponse +from .models.post import Post, PostFilter, PostsResponse from .utils import check_unix_socket_valid, get_message_type_value logger = logging.getLogger(__name__) @@ -141,18 +140,7 @@ def get_messages( self, pagination: int = 200, page: int = 1, - message_type: Optional[MessageType] = None, - message_types: Optional[List[MessageType]] = None, - content_types: Optional[Iterable[str]] = None, - content_keys: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + message_filter: Optional[MessageFilter] = None, ignore_invalid_messages: bool = True, invalid_messages_log_level: int = logging.NOTSET, ) -> MessagesResponse: @@ -160,18 +148,7 @@ def get_messages( self.async_session.get_messages, pagination=pagination, page=page, - message_type=message_type, - message_types=message_types, - content_types=content_types, - content_keys=content_keys, - refs=refs, - addresses=addresses, - tags=tags, - hashes=hashes, - channels=channels, - chains=chains, - start_date=start_date, - end_date=end_date, + message_filter=message_filter, ignore_invalid_messages=ignore_invalid_messages, invalid_messages_log_level=invalid_messages_log_level, ) @@ -210,29 +187,13 @@ def get_posts( self, pagination: int = 200, page: int = 1, - types: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + post_filter: Optional[PostFilter] = None, ) -> PostsResponse: return self._wrap( self.async_session.get_posts, pagination=pagination, page=page, - types=types, - refs=refs, - addresses=addresses, - tags=tags, - hashes=hashes, - channels=channels, - chains=chains, - start_date=start_date, - end_date=end_date, + post_filter=post_filter, ) def download_file(self, file_hash: str) -> bytes: @@ -264,16 +225,7 @@ def download_file_ipfs_to_buffer( def watch_messages( self, - message_type: 
Optional[MessageType] = None, - content_types: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + message_filter: Optional[MessageFilter] = None, ) -> Iterable[AlephMessage]: """ Iterate over current and future matching messages synchronously. @@ -286,18 +238,7 @@ def watch_messages( args=( output_queue, self.async_session.api_server, - ( - message_type, - content_types, - refs, - addresses, - tags, - hashes, - channels, - chains, - start_date, - end_date, - ), + (message_filter), {}, ), ) @@ -570,15 +511,7 @@ async def get_posts( self, pagination: int = 200, page: int = 1, - types: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + post_filter: Optional[PostFilter] = None, ignore_invalid_messages: Optional[bool] = True, invalid_messages_log_level: Optional[int] = logging.NOTSET, ) -> PostsResponse: @@ -591,31 +524,11 @@ async def get_posts( else invalid_messages_log_level ) - params: Dict[str, Any] = dict(pagination=pagination, page=page) - - if types is not None: - params["types"] = ",".join(types) - if refs is not None: - params["refs"] = ",".join(refs) - if addresses is not None: - params["addresses"] = ",".join(addresses) - if tags is not None: - params["tags"] = ",".join(tags) - if hashes is not None: - params["hashes"] = ",".join(hashes) - if channels is not None: - params["channels"] = ",".join(channels) - if chains is not None: - params["chains"] = ",".join(chains) - - if start_date is not None: - if not isinstance(start_date, float) and hasattr(start_date, "timestamp"): - start_date = start_date.timestamp() - params["startDate"] = start_date - if end_date is not None: - if not isinstance(end_date, float) and hasattr(start_date, "timestamp"): - end_date = end_date.timestamp() - params["endDate"] = end_date + if not post_filter: + post_filter = PostFilter() + params = post_filter.as_http_params() + params["page"] = str(page) + params["pagination"] = str(pagination) async with self.http_session.get("/api/v1/posts.json", params=params) as resp: resp.raise_for_status() @@ -722,18 +635,7 @@ async def get_messages( self, pagination: int = 200, page: int = 1, - message_type: Optional[MessageType] = None, - message_types: Optional[Iterable[MessageType]] = None, - content_types: Optional[Iterable[str]] = None, - content_keys: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + message_filter: Optional[MessageFilter] = None, ignore_invalid_messages: Optional[bool] = True, invalid_messages_log_level: Optional[int] = logging.NOTSET, ) -> MessagesResponse: @@ -746,43 +648,11 @@ 
async def get_messages( else invalid_messages_log_level ) - params: Dict[str, Any] = dict(pagination=pagination, page=page) - - if message_type is not None: - warnings.warn( - "The message_type parameter is deprecated, please use message_types instead.", - DeprecationWarning, - ) - params["msgType"] = message_type.value - if message_types is not None: - params["msgTypes"] = ",".join([t.value for t in message_types]) - print(params["msgTypes"]) - if content_types is not None: - params["contentTypes"] = ",".join(content_types) - if content_keys is not None: - params["contentKeys"] = ",".join(content_keys) - if refs is not None: - params["refs"] = ",".join(refs) - if addresses is not None: - params["addresses"] = ",".join(addresses) - if tags is not None: - params["tags"] = ",".join(tags) - if hashes is not None: - params["hashes"] = ",".join(hashes) - if channels is not None: - params["channels"] = ",".join(channels) - if chains is not None: - params["chains"] = ",".join(chains) - - if start_date is not None: - if not isinstance(start_date, float) and hasattr(start_date, "timestamp"): - start_date = start_date.timestamp() - params["startDate"] = start_date - if end_date is not None: - if not isinstance(end_date, float) and hasattr(start_date, "timestamp"): - end_date = end_date.timestamp() - params["endDate"] = end_date - + if not message_filter: + message_filter = MessageFilter() + params = message_filter.as_http_params() + params["page"] = str(page) + params["pagination"] = str(pagination) async with self.http_session.get( "/api/v0/messages.json", params=params ) as resp: @@ -825,8 +695,10 @@ async def get_message( channel: Optional[str] = None, ) -> GenericMessage: messages_response = await self.get_messages( - hashes=[item_hash], - channels=[channel] if channel else None, + message_filter=MessageFilter( + hashes=[item_hash], + channels=[channel] if channel else None, + ) ) if len(messages_response.messages) < 1: raise MessageNotFoundError(f"No such hash {item_hash}") @@ -846,54 +718,11 @@ async def get_message( async def watch_messages( self, - message_type: Optional[MessageType] = None, - message_types: Optional[Iterable[MessageType]] = None, - content_types: Optional[Iterable[str]] = None, - content_keys: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + message_filter: Optional[MessageFilter] = None, ) -> AsyncIterable[AlephMessage]: - params: Dict[str, Any] = dict() - - if message_type is not None: - warnings.warn( - "The message_type parameter is deprecated, please use message_types instead.", - DeprecationWarning, - ) - params["msgType"] = message_type.value - if message_types is not None: - params["msgTypes"] = ",".join([t.value for t in message_types]) - if content_types is not None: - params["contentTypes"] = ",".join(content_types) - if content_keys is not None: - params["contentKeys"] = ",".join(content_keys) - if refs is not None: - params["refs"] = ",".join(refs) - if addresses is not None: - params["addresses"] = ",".join(addresses) - if tags is not None: - params["tags"] = ",".join(tags) - if hashes is not None: - params["hashes"] = ",".join(hashes) - if channels is not None: - params["channels"] = ",".join(channels) - if chains is 
not None:
-            params["chains"] = ",".join(chains)
-
-        if start_date is not None:
-            if not isinstance(start_date, float) and hasattr(start_date, "timestamp"):
-                start_date = start_date.timestamp()
-            params["startDate"] = start_date
-        if end_date is not None:
-            if not isinstance(end_date, float) and hasattr(start_date, "timestamp"):
-                end_date = end_date.timestamp()
-            params["endDate"] = end_date
+        if not message_filter:
+            message_filter = MessageFilter()
+        params = message_filter.as_http_params()
 
         async with self.http_session.ws_connect(
             "/api/ws0/messages", params=params
diff --git a/src/aleph/sdk/models.py b/src/aleph/sdk/models.py
deleted file mode 100644
index 40083bb2..00000000
--- a/src/aleph/sdk/models.py
+++ /dev/null
@@ -1,56 +0,0 @@
-from datetime import datetime
-from typing import Any, Dict, List, Optional
-
-from aleph_message.models import AlephMessage, ItemHash
-from pydantic import BaseModel, Field
-
-
-class PaginationResponse(BaseModel):
-    pagination_page: int
-    pagination_total: int
-    pagination_per_page: int
-    pagination_item: str
-
-
-class MessagesResponse(PaginationResponse):
-    """Response from an Aleph node API on the path /api/v0/messages.json"""
-
-    messages: List[AlephMessage]
-    pagination_item = "messages"
-
-
-class Post(BaseModel):
-    """
-    A post is a type of message that can be updated. Over the get_posts API
-    we get the latest version of a post.
-    """
-
-    item_hash: ItemHash = Field(description="Hash of the content (sha256 by default)")
-    content: Dict[str, Any] = Field(
-        description="The content.content of the POST message"
-    )
-    original_item_hash: ItemHash = Field(
-        description="Hash of the original content (sha256 by default)"
-    )
-    original_type: str = Field(
-        description="The original, user-generated 'content-type' of the POST message"
-    )
-    address: str = Field(description="The address of the sender of the POST message")
-    ref: Optional[str] = Field(description="Other message referenced by this one")
-    channel: Optional[str] = Field(
-        description="The channel where the POST message was published"
-    )
-    created: datetime = Field(description="The time when the POST message was created")
-    last_updated: datetime = Field(
-        description="The time when the POST message was last updated"
-    )
-
-    class Config:
-        allow_extra = False
-
-
-class PostsResponse(PaginationResponse):
-    """Response from an Aleph node API on the path /api/v0/posts.json"""
-
-    posts: List[Post]
-    pagination_item = "posts"
diff --git a/src/aleph/sdk/models/__init__.py b/src/aleph/sdk/models/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/aleph/sdk/models/common.py b/src/aleph/sdk/models/common.py
new file mode 100644
index 00000000..bb261683
--- /dev/null
+++ b/src/aleph/sdk/models/common.py
@@ -0,0 +1,39 @@
+from datetime import datetime
+from typing import Iterable, Optional, Type, Union
+
+from peewee import Model
+from pydantic import BaseModel
+
+
+class PaginationResponse(BaseModel):
+    pagination_page: int
+    pagination_total: int
+    pagination_per_page: int
+    pagination_item: str
+
+
+def serialize_list(values: Optional[Iterable[str]]) -> Optional[str]:
+    if values:
+        return ",".join(values)
+    else:
+        return None
+
+
+def _date_field_to_float(date: Optional[Union[datetime, float]]) -> Optional[float]:
+    if date is None:
+        return None
+    elif isinstance(date, float):
+        return date
+    elif hasattr(date, "timestamp"):
+        return date.timestamp()
+    else:
+        raise TypeError(f"Invalid type: `{type(date)}`")
+
+
+def query_db_field(db_model: Type[Model], field_name: str, field_values: Iterable[str]):
+    field = getattr(db_model, field_name)
+    values = list(field_values)
+
+    if len(values) == 1:
+        return field == values[0]
+    return field.in_(values)
diff --git a/src/aleph/sdk/models/db/__init__.py b/src/aleph/sdk/models/db/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/aleph/sdk/node/common.py b/src/aleph/sdk/models/db/common.py
similarity index 100%
rename from src/aleph/sdk/node/common.py
rename to src/aleph/sdk/models/db/common.py
diff --git a/src/aleph/sdk/models/db/message.py b/src/aleph/sdk/models/db/message.py
new file mode 100644
index 00000000..f53eb676
--- /dev/null
+++ b/src/aleph/sdk/models/db/message.py
@@ -0,0 +1,36 @@
+from aleph_message.models import MessageConfirmation
+from peewee import BooleanField, CharField, FloatField, IntegerField, Model
+from playhouse.sqlite_ext import JSONField
+
+from .common import PydanticField, db, pydantic_json_dumps
+
+
+class MessageDBModel(Model):
+    """
+    A simple database model for storing AlephMessage objects.
+    """
+
+    item_hash = CharField(primary_key=True)
+    chain = CharField(5)
+    type = CharField(9)
+    sender = CharField()
+    channel = CharField(null=True)
+    confirmations: PydanticField[MessageConfirmation] = PydanticField(
+        type=MessageConfirmation, null=True
+    )
+    confirmed = BooleanField(null=True)
+    signature = CharField(null=True)
+    size = IntegerField(null=True)
+    time = FloatField()
+    item_type = CharField(7)
+    item_content = CharField(null=True)
+    hash_type = CharField(6, null=True)
+    content = JSONField(json_dumps=pydantic_json_dumps)
+    forgotten_by = CharField(null=True)
+    tags = JSONField(json_dumps=pydantic_json_dumps, null=True)
+    key = CharField(null=True)
+    ref = CharField(null=True)
+    content_type = CharField(null=True)
+
+    class Meta:
+        database = db
diff --git a/src/aleph/sdk/models/db/post.py b/src/aleph/sdk/models/db/post.py
new file mode 100644
index 00000000..7f634d54
--- /dev/null
+++ b/src/aleph/sdk/models/db/post.py
@@ -0,0 +1,25 @@
+from peewee import CharField, DateTimeField, Model
+from playhouse.sqlite_ext import JSONField
+
+from .common import db, pydantic_json_dumps
+
+
+class PostDBModel(Model):
+    """
+    A simple database model for storing Post objects.
+    """
+
+    original_item_hash = CharField(primary_key=True)
+    item_hash = CharField()
+    content = JSONField(json_dumps=pydantic_json_dumps)
+    original_type = CharField()
+    address = CharField()
+    ref = CharField(null=True)
+    channel = CharField(null=True)
+    created = DateTimeField()
+    last_updated = DateTimeField()
+    tags = JSONField(json_dumps=pydantic_json_dumps, null=True)
+    chain = CharField(5)
+
+    class Meta:
+        database = db
diff --git a/src/aleph/sdk/models/message.py b/src/aleph/sdk/models/message.py
new file mode 100644
index 00000000..f695e883
--- /dev/null
+++ b/src/aleph/sdk/models/message.py
@@ -0,0 +1,190 @@
+from datetime import datetime
+from typing import Any, Dict, Iterable, List, Optional, Union
+
+from aleph_message import parse_message
+from aleph_message.models import AlephMessage, MessageType
+from playhouse.shortcuts import model_to_dict
+
+from .common import (
+    PaginationResponse,
+    _date_field_to_float,
+    query_db_field,
+    serialize_list,
+)
+from .db.message import MessageDBModel
+
+
+class MessagesResponse(PaginationResponse):
+    """Response from an Aleph node API on the path /api/v0/messages.json"""
+
+    messages: List[AlephMessage]
+    pagination_item = "messages"
+
+
+class MessageFilter:
+    """
+    A collection of filters that can be applied on message queries.
+    :param message_types: Filter by message type
+    :param content_types: Filter by content type
+    :param content_keys: Filter by content key
+    :param refs: If set, only fetch messages that reference these hashes (in the "refs" field)
+    :param addresses: Addresses of the messages to fetch (Default: all addresses)
+    :param tags: Tags of the messages to fetch (Default: all tags)
+    :param hashes: Specific item_hashes to fetch
+    :param channels: Channels of the messages to fetch (Default: all channels)
+    :param chains: Filter by sender address chain
+    :param start_date: Earliest date to fetch messages from
+    :param end_date: Latest date to fetch messages from
+    """
+
+    message_types: Optional[Iterable[MessageType]]
+    content_types: Optional[Iterable[str]]
+    content_keys: Optional[Iterable[str]]
+    refs: Optional[Iterable[str]]
+    addresses: Optional[Iterable[str]]
+    tags: Optional[Iterable[str]]
+    hashes: Optional[Iterable[str]]
+    channels: Optional[Iterable[str]]
+    chains: Optional[Iterable[str]]
+    start_date: Optional[Union[datetime, float]]
+    end_date: Optional[Union[datetime, float]]
+
+    def __init__(
+        self,
+        message_types: Optional[Iterable[MessageType]] = None,
+        content_types: Optional[Iterable[str]] = None,
+        content_keys: Optional[Iterable[str]] = None,
+        refs: Optional[Iterable[str]] = None,
+        addresses: Optional[Iterable[str]] = None,
+        tags: Optional[Iterable[str]] = None,
+        hashes: Optional[Iterable[str]] = None,
+        channels: Optional[Iterable[str]] = None,
+        chains: Optional[Iterable[str]] = None,
+        start_date: Optional[Union[datetime, float]] = None,
+        end_date: Optional[Union[datetime, float]] = None,
+    ):
+        self.message_types = message_types
+        self.content_types = content_types
+        self.content_keys = content_keys
+        self.refs = refs
+        self.addresses = addresses
+        self.tags = tags
+        self.hashes = hashes
+        self.channels = channels
+        self.chains = chains
+        self.start_date = start_date
+        self.end_date = end_date
+
+    def as_http_params(self) -> Dict[str, str]:
+        """Convert the filters into a dict that can be used by an `aiohttp` client
+        as `params` to build the HTTP query string.
+        """
+
+        partial_result = {
+            "msgTypes": serialize_list(
+                [type.value for type in self.message_types]
+                if self.message_types
+                else None
+            ),
+            "contentTypes": serialize_list(self.content_types),
+            "contentKeys": serialize_list(self.content_keys),
+            "refs": serialize_list(self.refs),
+            "addresses": serialize_list(self.addresses),
+            "tags": serialize_list(self.tags),
+            "hashes": serialize_list(self.hashes),
+            "channels": serialize_list(self.channels),
+            "chains": serialize_list(self.chains),
+            "startDate": _date_field_to_float(self.start_date),
+            "endDate": _date_field_to_float(self.end_date),
+        }
+
+        # Ensure all values are strings.
+        result: Dict[str, str] = {}
+
+        # Drop empty values; startDate/endDate arrive as floats, so coerce
+        # the remaining values to strings instead of asserting on their type.
+        for key, value in partial_result.items():
+            if value:
+                result[key] = str(value)
+
+        return result
+
+    def as_db_query(self):
+        query = MessageDBModel.select().order_by(MessageDBModel.time.desc())
+        conditions = []
+        if self.message_types:
+            conditions.append(
+                query_db_field(
+                    MessageDBModel, "type", [type.value for type in self.message_types]
+                )
+            )
+        if self.content_keys:
+            conditions.append(query_db_field(MessageDBModel, "key", self.content_keys))
+        if self.content_types:
+            conditions.append(
+                query_db_field(MessageDBModel, "content_type", self.content_types)
+            )
+        if self.refs:
+            conditions.append(query_db_field(MessageDBModel, "ref", self.refs))
+        if self.addresses:
+            conditions.append(query_db_field(MessageDBModel, "sender", self.addresses))
+        if self.tags:
+            for tag in self.tags:
+                conditions.append(MessageDBModel.tags.contains(tag))
+        if self.hashes:
+            conditions.append(query_db_field(MessageDBModel, "item_hash", self.hashes))
+        if self.channels:
+            conditions.append(query_db_field(MessageDBModel, "channel", self.channels))
+        if self.chains:
+            conditions.append(query_db_field(MessageDBModel, "chain", self.chains))
+        if self.start_date:
+            # MessageDBModel.time is a float timestamp; normalize datetimes first.
+            conditions.append(MessageDBModel.time >= _date_field_to_float(self.start_date))
+        if self.end_date:
+            conditions.append(MessageDBModel.time <= _date_field_to_float(self.end_date))
+
+        if conditions:
+            query = query.where(*conditions)
+        return query
+
+
+def message_to_model(message: AlephMessage) -> Dict:
+    return {
+        "item_hash": str(message.item_hash),
+        "chain": message.chain,
+        "type": message.type,
+        "sender": message.sender,
+        "channel": message.channel,
+        "confirmations": message.confirmations[0] if message.confirmations else None,
+        "confirmed": message.confirmed,
+        "signature": message.signature,
+        "size": message.size,
+        "time": message.time,
+        "item_type": message.item_type,
+        "item_content": message.item_content,
+        "hash_type": message.hash_type,
+        "content": message.content,
+        "forgotten_by": message.forgotten_by[0] if message.forgotten_by else None,
+        "tags": message.content.content.get("tags", None)
+        if hasattr(message.content, "content")
+        else None,
+        "key": message.content.key if hasattr(message.content, "key") else None,
+        "ref": message.content.ref if hasattr(message.content, "ref") else None,
+        "content_type": message.content.type
+        if hasattr(message.content, "type")
+        else None,
+    }
+
+
+def model_to_message(item: Any) -> AlephMessage:
+    item.confirmations = [item.confirmations] if item.confirmations else []
+    item.forgotten_by = [item.forgotten_by] if item.forgotten_by else None
+
+    to_exclude = [
+        MessageDBModel.tags,
+        MessageDBModel.ref,
+        MessageDBModel.key,
+        MessageDBModel.content_type,
+    ]
+
+    item_dict = model_to_dict(item, exclude=to_exclude)
+    return parse_message(item_dict)
diff --git 
a/src/aleph/sdk/models/post.py b/src/aleph/sdk/models/post.py new file mode 100644 index 00000000..b0c4445d --- /dev/null +++ b/src/aleph/sdk/models/post.py @@ -0,0 +1,170 @@ +from datetime import datetime +from typing import Any, Dict, Iterable, List, Optional, Union + +from aleph_message.models import ItemHash, PostMessage +from playhouse.shortcuts import model_to_dict +from pydantic import BaseModel, Field + +from .common import ( + PaginationResponse, + _date_field_to_float, + query_db_field, + serialize_list, +) +from .db.post import PostDBModel + + +class Post(BaseModel): + """ + A post is a type of message that can be updated. Over the get_posts API + we get the latest version of a post. + """ + + item_hash: ItemHash = Field(description="Hash of the content (sha256 by default)") + content: Dict[str, Any] = Field( + description="The content.content of the POST message" + ) + original_item_hash: ItemHash = Field( + description="Hash of the original content (sha256 by default)" + ) + original_type: str = Field( + description="The original, user-generated 'content-type' of the POST message" + ) + address: str = Field(description="The address of the sender of the POST message") + ref: Optional[str] = Field(description="Other message referenced by this one") + channel: Optional[str] = Field( + description="The channel where the POST message was published" + ) + created: datetime = Field(description="The time when the POST message was created") + last_updated: datetime = Field( + description="The time when the POST message was last updated" + ) + + class Config: + allow_extra = False + orm_mode = True + + @classmethod + def from_orm(cls, obj: Any) -> "Post": + if isinstance(obj, PostDBModel): + return Post.parse_obj(model_to_dict(obj)) + return super().from_orm(obj) + + @classmethod + def from_message(cls, message: PostMessage) -> "Post": + return Post.parse_obj( + { + "item_hash": str(message.item_hash), + "content": message.content.content, + "original_item_hash": str(message.item_hash), + "original_type": message.content.type + if hasattr(message.content, "type") + else None, + "address": message.sender, + "ref": message.content.ref if hasattr(message.content, "ref") else None, + "channel": message.channel, + "created": datetime.fromtimestamp(message.time), + "last_updated": datetime.fromtimestamp(message.time), + } + ) + + +class PostsResponse(PaginationResponse): + """Response from an Aleph node API on the path /api/v0/posts.json""" + + posts: List[Post] + pagination_item = "posts" + + +class PostFilter: + """ + A collection of filters that can be applied on post queries. 
+
+    :param types: Original types of the posts to fetch (Default: all types)
+    :param refs: If set, only fetch posts that reference these hashes (in the "refs" field)
+    :param addresses: Addresses of the posts to fetch (Default: all addresses)
+    :param tags: Tags of the posts to fetch (Default: all tags)
+    :param hashes: Specific item_hashes to fetch
+    :param channels: Channels of the posts to fetch (Default: all channels)
+    :param chains: Filter by sender address chain
+    :param start_date: Earliest date to fetch posts from
+    :param end_date: Latest date to fetch posts from
+    """
+
+    types: Optional[Iterable[str]]
+    refs: Optional[Iterable[str]]
+    addresses: Optional[Iterable[str]]
+    tags: Optional[Iterable[str]]
+    hashes: Optional[Iterable[str]]
+    channels: Optional[Iterable[str]]
+    chains: Optional[Iterable[str]]
+    start_date: Optional[Union[datetime, float]]
+    end_date: Optional[Union[datetime, float]]
+
+    def __init__(
+        self,
+        types: Optional[Iterable[str]] = None,
+        refs: Optional[Iterable[str]] = None,
+        addresses: Optional[Iterable[str]] = None,
+        tags: Optional[Iterable[str]] = None,
+        hashes: Optional[Iterable[str]] = None,
+        channels: Optional[Iterable[str]] = None,
+        chains: Optional[Iterable[str]] = None,
+        start_date: Optional[Union[datetime, float]] = None,
+        end_date: Optional[Union[datetime, float]] = None,
+    ):
+        self.types = types
+        self.refs = refs
+        self.addresses = addresses
+        self.tags = tags
+        self.hashes = hashes
+        self.channels = channels
+        self.chains = chains
+        self.start_date = start_date
+        self.end_date = end_date
+
+    def as_http_params(self) -> Dict[str, str]:
+        """Convert the filters into a dict that can be used by an `aiohttp` client
+        as `params` to build the HTTP query string.
+        """
+
+        partial_result = {
+            "types": serialize_list(self.types),
+            "refs": serialize_list(self.refs),
+            "addresses": serialize_list(self.addresses),
+            "tags": serialize_list(self.tags),
+            "hashes": serialize_list(self.hashes),
+            "channels": serialize_list(self.channels),
+            "chains": serialize_list(self.chains),
+            "startDate": _date_field_to_float(self.start_date),
+            "endDate": _date_field_to_float(self.end_date),
+        }
+
+        # Ensure all values are strings.
+        result: Dict[str, str] = {}
+
+        # Drop empty values; startDate/endDate arrive as floats, so coerce
+        # the remaining values to strings instead of asserting on their type.
+        for key, value in partial_result.items():
+            if value:
+                result[key] = str(value)
+
+        return result
+
+    def as_db_query(self):
+        query = PostDBModel.select().order_by(PostDBModel.created.desc())
+        conditions = []
+        if self.types:
+            conditions.append(query_db_field(PostDBModel, "original_type", self.types))
+        if self.refs:
+            conditions.append(query_db_field(PostDBModel, "ref", self.refs))
+        if self.addresses:
+            conditions.append(query_db_field(PostDBModel, "address", self.addresses))
+        if self.tags:
+            for tag in self.tags:
+                conditions.append(PostDBModel.tags.contains(tag))
+        if self.hashes:
+            conditions.append(query_db_field(PostDBModel, "item_hash", self.hashes))
+        if self.channels:
+            conditions.append(query_db_field(PostDBModel, "channel", self.channels))
+        if self.chains:
+            conditions.append(query_db_field(PostDBModel, "chain", self.chains))
+        if self.start_date:
+            # PostDBModel has no `time` column; compare against `created`,
+            # normalizing float timestamps to datetimes first.
+            start_date = self.start_date
+            if isinstance(start_date, float):
+                start_date = datetime.fromtimestamp(start_date)
+            conditions.append(PostDBModel.created >= start_date)
+        if self.end_date:
+            end_date = self.end_date
+            if isinstance(end_date, float):
+                end_date = datetime.fromtimestamp(end_date)
+            conditions.append(PostDBModel.created <= end_date)
+
+        if conditions:
+            query = query.where(*conditions)
+        return query
diff --git a/src/aleph/sdk/node/__init__.py b/src/aleph/sdk/node.py
similarity index 68%
rename from src/aleph/sdk/node/__init__.py
rename to src/aleph/sdk/node.py
index 42ac739b..1e091e49 100644
--- a/src/aleph/sdk/node/__init__.py
+++ b/src/aleph/sdk/node.py
@@ -23,15 +23,16 @@
 from aleph_message.models.execution.base import Encoding
 from aleph_message.status import MessageStatus
 
-from ..base import BaseAlephClient, BaseAuthenticatedAlephClient
-from ..client import AuthenticatedAlephClient
-from ..conf import settings
-from ..exceptions import MessageNotFoundError
-from ..models import PostsResponse
-from ..types import GenericMessage, StorageEnum
-from .common import db
-from .message import MessageModel, 
get_message_query, message_to_model, model_to_message -from .post import PostModel, get_post_query, message_to_post, model_to_post +from .base import BaseAlephClient, BaseAuthenticatedAlephClient +from .client import AuthenticatedAlephClient +from .conf import settings +from .exceptions import MessageNotFoundError +from .models.db.common import db +from .models.db.message import MessageDBModel +from .models.db.post import PostDBModel +from .models.message import MessageFilter, message_to_model, model_to_message +from .models.post import Post, PostFilter, PostsResponse +from .types import GenericMessage, StorageEnum class MessageCache(BaseAlephClient): @@ -48,10 +49,10 @@ class MessageCache(BaseAlephClient): def __init__(self): if db.is_closed(): db.connect() - if not MessageModel.table_exists(): - db.create_tables([MessageModel]) - if not PostModel.table_exists(): - db.create_tables([PostModel]) + if not MessageDBModel.table_exists(): + db.create_tables([MessageDBModel]) + if not PostDBModel.table_exists(): + db.create_tables([PostDBModel]) MessageCache._instance_count += 1 @@ -63,29 +64,31 @@ def __del__(self): def __getitem__(self, item_hash: Union[ItemHash, str]) -> Optional[AlephMessage]: try: - item = MessageModel.get(MessageModel.item_hash == str(item_hash)) - except MessageModel.DoesNotExist: + item = MessageDBModel.get(MessageDBModel.item_hash == str(item_hash)) + except MessageDBModel.DoesNotExist: return None return model_to_message(item) def __delitem__(self, item_hash: Union[ItemHash, str]): - MessageModel.delete().where(MessageModel.item_hash == str(item_hash)).execute() + MessageDBModel.delete().where( + MessageDBModel.item_hash == str(item_hash) + ).execute() def __contains__(self, item_hash: Union[ItemHash, str]) -> bool: return ( - MessageModel.select() - .where(MessageModel.item_hash == str(item_hash)) + MessageDBModel.select() + .where(MessageDBModel.item_hash == str(item_hash)) .exists() ) def __len__(self): - return MessageModel.select().count() + return MessageDBModel.select().count() def __iter__(self) -> Iterator[AlephMessage]: """ Iterate over all messages in the cache, the latest first. """ - for item in iter(MessageModel.select().order_by(-MessageModel.time)): + for item in iter(MessageDBModel.select().order_by(-MessageDBModel.time)): yield model_to_message(item) def __repr__(self) -> str: @@ -95,15 +98,19 @@ def __str__(self) -> str: return repr(self) def add(self, messages: Union[AlephMessage, Iterable[AlephMessage]]): + """ + Add a message or a list of messages to the cache. If the message is a post, it will also be added to the + PostDBModel. Any subsequent amend messages will be used to update the original post in the PostDBModel. 
+        """
         if isinstance(messages, typing.get_args(AlephMessage)):
             messages = [messages]
 
         messages = list(messages)
 
         message_data = (message_to_model(message) for message in messages)
-        MessageModel.insert_many(message_data).on_conflict_replace().execute()
+        MessageDBModel.insert_many(message_data).on_conflict_replace().execute()
 
-        # Add posts and their amends to the PostModel
+        # Add posts and their amends to the PostDBModel
         post_data = []
         amend_messages = []
         for message in messages:
@@ -112,7 +119,7 @@ def add(self, messages: Union[AlephMessage, Iterable[AlephMessage]]):
             if message.content.type == "amend":
                 amend_messages.append(message)
                 continue
-            post = message_to_post(message).dict()
+            post = Post.from_message(message).dict()
             post["chain"] = message.chain.value
             post["tags"] = message.content.content.get("tags", None)
             post_data.append(post)
@@ -120,14 +127,14 @@ def add(self, messages: Union[AlephMessage, Iterable[AlephMessage]]):
             if message.item_hash in self.missing_posts:
                 amend_messages += self.missing_posts.pop(message.item_hash)
 
-        PostModel.insert_many(post_data).on_conflict_replace().execute()
+        PostDBModel.insert_many(post_data).on_conflict_replace().execute()
 
         # Handle amends in second step to avoid missing original posts
-        post_data = []
         for message in amend_messages:
+            logging.debug(f"Adding amend {message.item_hash} to cache")
            # Find the original post and update it; get_or_none() avoids an
            # unhandled DoesNotExist before the `if not original_post` check
-            original_post = MessageModel.get(
-                MessageModel.item_hash == message.content.ref
+            original_post = PostDBModel.get_or_none(
+                PostDBModel.item_hash == message.content.ref
             )
             if not original_post:
                 latest_amend = self.missing_posts.get(ItemHash(message.content.ref))
@@ -136,16 +143,13 @@ def add(self, messages: Union[AlephMessage, Iterable[AlephMessage]]):
                 continue
             if datetime.fromtimestamp(message.time) < original_post.last_updated:
                 continue
-            original_post.item_hash = message.item_hash
             original_post.content = message.content.content
             original_post.original_item_hash = message.content.ref
             original_post.original_type = message.content.type
             original_post.address = message.sender
             original_post.channel = message.channel
             original_post.last_updated = datetime.fromtimestamp(message.time)
-            post_data.append(original_post)
-
-        PostModel.insert_many(post_data).on_conflict_replace().execute()
+            original_post.save()
 
     def get(
         self, item_hashes: Union[Union[ItemHash, str], Iterable[Union[ItemHash, str]]]
@@ -157,8 +161,8 @@ def get(
             item_hashes = [item_hashes]
         item_hashes = [str(item_hash) for item_hash in item_hashes]
         items = (
-            MessageModel.select()
-            .where(MessageModel.item_hash.in_(item_hashes))
+            MessageDBModel.select()
+            .where(MessageDBModel.item_hash.in_(item_hashes))
             .execute()
         )
         return [model_to_message(item) for item in items]
@@ -171,7 +175,7 @@ def listen_to(self, message_stream: AsyncIterable[AlephMessage]) -> Coroutine:
         async def _listen():
             async for message in message_stream:
                 self.add(message)
-                print(f"Added message {message.item_hash} to cache")
+                logging.info(f"Added message {message.item_hash} to cache")
 
         return _listen()
 
@@ -179,11 +183,11 @@ async def fetch_aggregate(
         self, address: str, key: str, limit: int = 100
     ) -> Dict[str, Dict]:
         item = (
-            MessageModel.select()
-            .where(MessageModel.type == MessageType.aggregate.value)
-            .where(MessageModel.sender == address)
-            .where(MessageModel.key == key)
-            .order_by(MessageModel.time.desc())
+            MessageDBModel.select()
+            .where(MessageDBModel.type == MessageType.aggregate.value)
+            .where(MessageDBModel.sender == address)
+            .where(MessageDBModel.key == key)
+            .order_by(MessageDBModel.time.desc())
             .first()
         )
         return item.content["content"]
@@ -192,13 +196,13 @@ async def fetch_aggregates(
         self, address: str, keys: Optional[Iterable[str]] = None, limit: int = 100
     ) -> Dict[str, Dict]:
         query = (
-            MessageModel.select()
-            .where(MessageModel.type == MessageType.aggregate.value)
-            .where(MessageModel.sender == address)
-            .order_by(MessageModel.time.desc())
+            MessageDBModel.select()
+            .where(MessageDBModel.type == MessageType.aggregate.value)
+            .where(MessageDBModel.sender == address)
+            .order_by(MessageDBModel.time.desc())
         )
         if keys:
-            query = query.where(MessageModel.key.in_(keys))
+            query = query.where(MessageDBModel.key.in_(keys))
         query = query.limit(limit)
         return {item.key: item.content["content"] for item in list(query)}
@@ -206,33 +210,17 @@ async def get_posts(
         self,
         pagination: int = 200,
         page: int = 1,
-        types: Optional[Iterable[str]] = None,
-        refs: Optional[Iterable[str]] = None,
-        addresses: Optional[Iterable[str]] = None,
-        tags: Optional[Iterable[str]] = None,
-        hashes: Optional[Iterable[str]] = None,
-        channels: Optional[Iterable[str]] = None,
-        chains: Optional[Iterable[str]] = None,
-        start_date: Optional[Union[datetime, float]] = None,
-        end_date: Optional[Union[datetime, float]] = None,
+        post_filter: Optional[PostFilter] = None,
         ignore_invalid_messages: Optional[bool] = True,
         invalid_messages_log_level: Optional[int] = logging.NOTSET,
     ) -> PostsResponse:
-        query = get_post_query(
-            types=types,
-            refs=refs,
-            addresses=addresses,
-            tags=tags,
-            hashes=hashes,
-            channels=channels,
-            chains=chains,
-            start_date=start_date,
-            end_date=end_date,
-        )
+        if not post_filter:
+            post_filter = PostFilter()
+        query = post_filter.as_db_query()
 
         query = query.paginate(page, pagination)
 
-        posts = [model_to_post(item) for item in list(query)]
+        posts = [Post.from_orm(item) for item in list(query)]
 
         return PostsResponse(
             posts=posts,
@@ -249,38 +237,17 @@ async def get_messages(
         self,
         pagination: int = 200,
         page: int = 1,
-        message_type: Optional[MessageType] = None,
-        message_types: Optional[Iterable[MessageType]] = None,
-        content_types: Optional[Iterable[str]] = None,
-        content_keys: Optional[Iterable[str]] = None,
-        refs: Optional[Iterable[str]] = None,
-        addresses: Optional[Iterable[str]] = None,
-        tags: Optional[Iterable[str]] = None,
-        hashes: Optional[Iterable[str]] = None,
-        channels: Optional[Iterable[str]] = None,
-        chains: Optional[Iterable[str]] = None,
-        start_date: Optional[Union[datetime, float]] = None,
-        end_date: Optional[Union[datetime, float]] = None,
+        message_filter: Optional[MessageFilter] = None,
         ignore_invalid_messages: Optional[bool] = True,
        invalid_messages_log_level: Optional[int] = logging.NOTSET,
     ) -> MessagesResponse:
         """
         Get many messages from the cache.
         """
-        message_types = message_types or [message_type] if message_type else None
-        query = get_message_query(
-            message_types=message_types,
-            content_keys=content_keys,
-            content_types=content_types,
-            refs=refs,
-            addresses=addresses,
-            tags=tags,
-            hashes=hashes,
-            channels=channels,
-            chains=chains,
-            start_date=start_date,
-            end_date=end_date,
-        )
+        if not message_filter:
+            message_filter = MessageFilter()
+
+        query = message_filter.as_db_query()
 
         query = query.paginate(page, pagination)
@@ -303,12 +270,12 @@ async def get_message(
         """
         Get a single message from the cache.
""" - query = MessageModel.select().where(MessageModel.item_hash == item_hash) + query = MessageDBModel.select().where(MessageDBModel.item_hash == item_hash) if message_type: - query = query.where(MessageModel.type == message_type.value) + query = query.where(MessageDBModel.type == message_type.value) if channel: - query = query.where(MessageModel.channel == channel) + query = query.where(MessageDBModel.channel == channel) item = query.first() @@ -319,36 +286,15 @@ async def get_message( async def watch_messages( self, - message_type: Optional[MessageType] = None, - message_types: Optional[Iterable[MessageType]] = None, - content_types: Optional[Iterable[str]] = None, - content_keys: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + message_filter: Optional[MessageFilter] = None, ) -> AsyncIterable[AlephMessage]: """ Watch messages from the cache. """ - message_types = message_types or [message_type] if message_type else None - query = get_message_query( - message_types=message_types, - content_keys=content_keys, - content_types=content_types, - refs=refs, - addresses=addresses, - tags=tags, - hashes=hashes, - channels=channels, - chains=chains, - start_date=start_date, - end_date=end_date, - ) + if not message_filter: + message_filter = MessageFilter() + + query = message_filter.as_db_query() async for item in query: yield model_to_message(item) @@ -362,42 +308,49 @@ class DomainNode(MessageCache, BaseAuthenticatedAlephClient): It synchronizes with the network on a subset of the messages (the "domain") by listening to the network and storing the messages in the cache. The user may define the domain by specifying a channels, tags, senders, chains and/or message types. + + By default, the domain is defined by the user's own address and used chain, meaning that the DomainNode will only + store and create messages that are sent by the user. 
""" + session: AuthenticatedAlephClient + message_filter: MessageFilter + watch_task: Optional[asyncio.Task] = None + def __init__( self, session: AuthenticatedAlephClient, - channels: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - chains: Optional[Iterable[Chain]] = None, - message_types: Optional[Iterable[MessageType]] = None, + message_filter: Optional[MessageFilter] = None, ): super().__init__() self.session = session - self.channels = channels - self.tags = tags - self.addresses = ( - list(addresses) + [session.account.get_address()] - if addresses - else [session.account.get_address()] + if not message_filter: + message_filter = MessageFilter() + message_filter.addresses = list( + set( + ( + list(message_filter.addresses) + [session.account.get_address()] + if message_filter.addresses + else [session.account.get_address()] + ) + ) ) - self.chains = ( - list(chains) + [Chain(session.account.CHAIN)] - if chains - else [session.account.CHAIN] + message_filter.chains = list( + set( + ( + list(message_filter.chains) + [Chain(session.account.CHAIN)] + if message_filter.chains + else [session.account.CHAIN] + ) + ) ) - self.message_types = message_types + self.message_filter = message_filter # start listening to the network and storing messages in the cache - asyncio.get_event_loop().create_task( + self.watch_task = asyncio.get_event_loop().create_task( self.listen_to( self.session.watch_messages( - channels=self.channels, - tags=self.tags, - addresses=self.addresses, - chains=self.chains, - message_types=self.message_types, + message_filter=self.message_filter, ) ) ) @@ -405,29 +358,31 @@ def __init__( # synchronize with past messages asyncio.get_event_loop().run_until_complete( self.synchronize( - channels=self.channels, - tags=self.tags, - addresses=self.addresses, - chains=self.chains, - message_types=self.message_types, + message_filter=self.message_filter, ) ) + def __del__(self): + if self.watch_task: + self.watch_task.cancel() + + def __exit__(self, exc_type, exc_val, exc_tb): + close_fut = self.session.__aexit__(exc_type, exc_val, exc_tb) + try: + loop = asyncio.get_running_loop() + loop.run_until_complete(close_fut) + except RuntimeError: + asyncio.run(close_fut) + async def __aenter__(self) -> "DomainNode": return self async def __aexit__(self, exc_type, exc_val, exc_tb): - ... + await self.session.__aexit__(exc_type, exc_val, exc_tb) async def synchronize( self, - channels: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - chains: Optional[Iterable[Chain]] = None, - message_types: Optional[Iterable[MessageType]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + message_filter: MessageFilter, ): """ Synchronize with past messages. 
@@ -435,13 +390,7 @@ async def synchronize(
         chunk_size = 200
         messages = []
         async for message in self.session.get_messages_iterator(
-            channels=channels,
-            tags=tags,
-            addresses=addresses,
-            chains=chains,
-            message_types=message_types,
-            start_date=start_date,
-            end_date=end_date,
+            message_filter=message_filter
         ):
             messages.append(message)
             if len(messages) >= chunk_size:
@@ -475,22 +424,33 @@ def check_validity(
         channel: Optional[str] = None,
         content: Optional[Dict] = None,
     ):
-        if self.message_types and message_type not in self.message_types:
+        if (
+            self.message_filter.message_types
+            and message_type not in self.message_filter.message_types
+        ):
             raise ValueError(
                 f"Cannot create {message_type.value} message because DomainNode is not listening to post messages."
             )
-        if address and self.addresses and address not in self.addresses:
+        if (
+            address
+            and self.message_filter.addresses
+            and address not in self.message_filter.addresses
+        ):
             raise ValueError(
                 f"Cannot create {message_type.value} message because DomainNode is not listening to messages from address {address}."
             )
-        if channel and self.channels and channel not in self.channels:
+        if (
+            channel
+            and self.message_filter.channels
+            and channel not in self.message_filter.channels
+        ):
             raise ValueError(
                 f"Cannot create {message_type.value} message because DomainNode is not listening to messages from channel {channel}."
             )
         if (
             content
-            and self.tags
-            and not set(content.get("tags", [])).intersection(self.tags)
+            and self.message_filter.tags
+            and not set(content.get("tags", [])).intersection(self.message_filter.tags)
         ):
             raise ValueError(
                 f"Cannot create {message_type.value} message because DomainNode is not listening to any of these tags: {content.get('tags', [])}."
             )
diff --git a/src/aleph/sdk/node/message.py b/src/aleph/sdk/node/message.py
deleted file mode 100644
index a3327d2a..00000000
--- a/src/aleph/sdk/node/message.py
+++ /dev/null
@@ -1,137 +0,0 @@
-from datetime import datetime
-from typing import Any, Dict, Iterable, Optional, Union
-
-from aleph_message import parse_message
-from aleph_message.models import AlephMessage, MessageConfirmation, MessageType
-from peewee import BooleanField, CharField, FloatField, IntegerField, Model
-from playhouse.shortcuts import model_to_dict
-from playhouse.sqlite_ext import JSONField
-
-from aleph.sdk.node.common import PydanticField, db, pydantic_json_dumps
-
-
-class MessageModel(Model):
-    """
-    A simple database model for storing AlephMessage objects. 
- """ - - item_hash = CharField(primary_key=True) - chain = CharField(5) - type = CharField(9) - sender = CharField() - channel = CharField(null=True) - confirmations: PydanticField[MessageConfirmation] = PydanticField( - type=MessageConfirmation, null=True - ) - confirmed = BooleanField(null=True) - signature = CharField(null=True) - size = IntegerField(null=True) - time = FloatField() - item_type = CharField(7) - item_content = CharField(null=True) - hash_type = CharField(6, null=True) - content = JSONField(json_dumps=pydantic_json_dumps) - forgotten_by = CharField(null=True) - tags = JSONField(json_dumps=pydantic_json_dumps, null=True) - key = CharField(null=True) - ref = CharField(null=True) - content_type = CharField(null=True) - - class Meta: - database = db - - -def message_to_model(message: AlephMessage) -> Dict: - return { - "item_hash": str(message.item_hash), - "chain": message.chain, - "type": message.type, - "sender": message.sender, - "channel": message.channel, - "confirmations": message.confirmations[0] if message.confirmations else None, - "confirmed": message.confirmed, - "signature": message.signature, - "size": message.size, - "time": message.time, - "item_type": message.item_type, - "item_content": message.item_content, - "hash_type": message.hash_type, - "content": message.content, - "forgotten_by": message.forgotten_by[0] if message.forgotten_by else None, - "tags": message.content.content.get("tags", None) - if hasattr(message.content, "content") - else None, - "key": message.content.key if hasattr(message.content, "key") else None, - "ref": message.content.ref if hasattr(message.content, "ref") else None, - "content_type": message.content.type - if hasattr(message.content, "type") - else None, - } - - -def model_to_message(item: Any) -> AlephMessage: - item.confirmations = [item.confirmations] if item.confirmations else [] - item.forgotten_by = [item.forgotten_by] if item.forgotten_by else None - - to_exclude = [ - MessageModel.tags, - MessageModel.ref, - MessageModel.key, - MessageModel.content_type, - ] - - item_dict = model_to_dict(item, exclude=to_exclude) - return parse_message(item_dict) - - -def query_field(field_name, field_values: Iterable[str]): - field = getattr(MessageModel, field_name) - values = list(field_values) - - if len(values) == 1: - return field == values[0] - return field.in_(values) - - -def get_message_query( - message_types: Optional[Iterable[MessageType]] = None, - content_keys: Optional[Iterable[str]] = None, - content_types: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, -): - query = MessageModel.select().order_by(MessageModel.time.desc()) - conditions = [] - if message_types: - conditions.append(query_field("type", [type.value for type in message_types])) - if content_keys: - conditions.append(query_field("key", content_keys)) - if content_types: - conditions.append(query_field("content_type", content_types)) - if refs: - conditions.append(query_field("ref", refs)) - if addresses: - conditions.append(query_field("sender", addresses)) - if tags: - for tag in tags: - conditions.append(MessageModel.tags.contains(tag)) - if hashes: - conditions.append(query_field("item_hash", hashes)) - if channels: 
- conditions.append(query_field("channel", channels)) - if chains: - conditions.append(query_field("chain", chains)) - if start_date: - conditions.append(MessageModel.time >= start_date) - if end_date: - conditions.append(MessageModel.time <= end_date) - - if conditions: - query = query.where(*conditions) - return query diff --git a/src/aleph/sdk/node/post.py b/src/aleph/sdk/node/post.py deleted file mode 100644 index b68a421d..00000000 --- a/src/aleph/sdk/node/post.py +++ /dev/null @@ -1,115 +0,0 @@ -from datetime import datetime -from typing import Any, Dict, Iterable, Optional, Union - -from aleph_message.models import PostMessage -from peewee import CharField, DateTimeField, Model -from playhouse.shortcuts import model_to_dict -from playhouse.sqlite_ext import JSONField - -from aleph.sdk.models import Post -from aleph.sdk.node.common import db, pydantic_json_dumps - - -class PostModel(Model): - """ - A simple database model for storing AlephMessage objects. - """ - - original_item_hash = CharField(primary_key=True) - item_hash = CharField() - content = JSONField(json_dumps=pydantic_json_dumps) - original_type = CharField() - address = CharField() - ref = CharField(null=True) - channel = CharField(null=True) - created = DateTimeField() - last_updated = DateTimeField() - tags = JSONField(json_dumps=pydantic_json_dumps, null=True) - chain = CharField(5) - - class Meta: - database = db - - -def post_to_model(post: Post) -> Dict: - return { - "item_hash": str(post.item_hash), - "content": post.content, - "original_item_hash": str(post.original_item_hash), - "original_type": post.original_type, - "address": post.address, - "ref": post.ref, - "channel": post.channel, - "created": post.created, - "last_updated": post.last_updated, - } - - -def message_to_post(message: PostMessage) -> Post: - return Post.parse_obj( - { - "item_hash": str(message.item_hash), - "content": message.content, - "original_item_hash": str(message.item_hash), - "original_type": message.content.type - if hasattr(message.content, "type") - else None, - "address": message.sender, - "ref": message.content.ref if hasattr(message.content, "ref") else None, - "channel": message.channel, - "created": datetime.fromtimestamp(message.time), - "last_updated": datetime.fromtimestamp(message.time), - } - ) - - -def model_to_post(item: Any) -> Post: - to_exclude = [PostModel.tags, PostModel.chain] - return Post.parse_obj(model_to_dict(item, exclude=to_exclude)) - - -def query_field(field_name, field_values: Iterable[str]): - field = getattr(PostModel, field_name) - values = list(field_values) - - if len(values) == 1: - return field == values[0] - return field.in_(values) - - -def get_post_query( - types: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, -): - query = PostModel.select().order_by(PostModel.created.desc()) - conditions = [] - if types: - conditions.append(query_field("original_type", types)) - if refs: - conditions.append(query_field("ref", refs)) - if addresses: - conditions.append(query_field("address", addresses)) - if tags: - for tag in tags: - conditions.append(PostModel.tags.contains(tag)) - if hashes: - conditions.append(query_field("item_hash", hashes)) - if channels: - 
conditions.append(query_field("channel", channels)) - if chains: - conditions.append(query_field("chain", chains)) - if start_date: - conditions.append(PostModel.time >= start_date) - if end_date: - conditions.append(PostModel.time <= end_date) - - if conditions: - query = query.where(*conditions) - return query diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 9a2a41c7..a51b1483 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,7 +1,7 @@ import json from pathlib import Path from tempfile import NamedTemporaryFile -from typing import List +from typing import Any, Callable, Dict, List import pytest as pytest from aleph_message.models import AggregateMessage, AlephMessage, PostMessage @@ -112,11 +112,13 @@ def aleph_messages() -> List[AlephMessage]: @pytest.fixture -def raw_messages_response(aleph_messages): +def raw_messages_response(aleph_messages) -> Callable[[int], Dict[str, Any]]: return lambda page: { - "messages": [message.dict() for message in aleph_messages] if page == 1 else [], + "messages": [message.dict() for message in aleph_messages] + if int(page) == 1 + else [], "pagination_item": "messages", - "pagination_page": page, + "pagination_page": int(page), "pagination_per_page": max(len(aleph_messages), 20), "pagination_total": len(aleph_messages) if page == 1 else 0, } diff --git a/tests/unit/test_asynchronous_get.py b/tests/unit/test_asynchronous_get.py index db788e0b..72c47706 100644 --- a/tests/unit/test_asynchronous_get.py +++ b/tests/unit/test_asynchronous_get.py @@ -3,11 +3,12 @@ from unittest.mock import AsyncMock import pytest -from aleph_message.models import MessagesResponse +from aleph_message.models import MessagesResponse, MessageType from aleph.sdk.client import AlephClient from aleph.sdk.conf import settings -from aleph.sdk.models import PostsResponse +from aleph.sdk.models.message import MessageFilter +from aleph.sdk.models.post import PostFilter, PostsResponse def make_mock_session(get_return_value: Dict[str, Any]) -> AlephClient: @@ -67,7 +68,12 @@ async def test_fetch_aggregates(): @pytest.mark.asyncio async def test_get_posts(): async with AlephClient(api_server=settings.API_HOST) as session: - response: PostsResponse = await session.get_posts() + response: PostsResponse = await session.get_posts( + pagination=2, + post_filter=PostFilter( + channels=["TEST"], + ), + ) posts = response.posts assert len(posts) > 1 @@ -78,6 +84,9 @@ async def test_get_messages(): async with AlephClient(api_server=settings.API_HOST) as session: response: MessagesResponse = await session.get_messages( pagination=2, + message_filter=MessageFilter( + message_types=[MessageType.post], + ), ) messages = response.messages diff --git a/tests/unit/test_node.py b/tests/unit/test_node.py index 6ea9d44c..f01399c2 100644 --- a/tests/unit/test_node.py +++ b/tests/unit/test_node.py @@ -1,12 +1,13 @@ import json import os from pathlib import Path -from typing import Any, Dict +from typing import Any, Callable, Dict, List from unittest.mock import AsyncMock, MagicMock import pytest as pytest from aleph_message.models import ( AggregateMessage, + AlephMessage, ForgetMessage, MessageType, PostMessage, @@ -17,6 +18,7 @@ from aleph.sdk import AuthenticatedAlephClient from aleph.sdk.conf import settings +from aleph.sdk.models.post import PostFilter from aleph.sdk.node import DomainNode from aleph.sdk.types import Account, StorageEnum @@ -54,7 +56,7 @@ async def text(self): class MockGetResponse: - def __init__(self, response_message, page=1): + def __init__(self, 
response_message: Callable[[int], Dict[str, Any]], page=1):
         self.response_message = response_message
         self.page = page
@@ -76,9 +78,32 @@ async def json(self):
         return self.response_message(self.page)
 
 
+class MockWsConnection:
+    def __init__(self, messages: List[AlephMessage]):
+        self.messages = messages
+        self.i = 0
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        ...
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        try:
+            message = self.messages[self.i]
+            self.i += 1
+            return message
+        except IndexError:
+            raise StopAsyncIteration
+
+
 @pytest.fixture
 def mock_session_with_two_messages(
-    ethereum_account: Account, raw_messages_response: Dict[str, Any]
+    ethereum_account: Account, raw_messages_response: Callable[[int], Dict[str, Any]]
 ) -> AuthenticatedAlephClient:
     http_session = AsyncMock()
     http_session.post = MagicMock()
@@ -97,6 +122,10 @@ def mock_session_with_two_messages(
         response_message=raw_messages_response,
         page=kwargs.get("params", {}).get("page", 1),
     )
+    http_session.ws_connect = MagicMock()
+    http_session.ws_connect.side_effect = lambda *args, **kwargs: MockWsConnection(
+        messages=raw_messages_response(1)["messages"]
+    )
 
     client = AuthenticatedAlephClient(
         account=ethereum_account, api_server="http://localhost"
@@ -106,9 +135,12 @@ def mock_session_with_two_messages(
     return client
 
 
-@pytest.mark.asyncio
-def test_node_init(mock_session_with_two_messages):
-    node = DomainNode(session=mock_session_with_two_messages)
+def test_node_init(mock_session_with_two_messages, aleph_messages):
+    node = DomainNode(
+        session=mock_session_with_two_messages,
+    )
+    assert mock_session_with_two_messages.http_session.get.called
+    assert mock_session_with_two_messages.http_session.ws_connect.called
     assert node.session == mock_session_with_two_messages
     assert len(node) >= 2
 
@@ -257,3 +289,40 @@ async def test_submit_message(mock_node_with_post_success):
     assert mock_node_with_post_success.session.http_session.post.called_once
     assert message.content.content == content
     assert status == MessageStatus.PENDING
+
+
+@pytest.mark.asyncio
+async def test_amend_post(mock_node_with_post_success):
+    async with mock_node_with_post_success as node:
+        post_message, status = await node.create_post(
+            post_content={
+                "Hello": "World",
+            },
+            post_type="to-be-amended",
+            channel="TEST",
+        )
+
+    assert mock_node_with_post_success.session.http_session.post.called
+    assert post_message.content.content == {"Hello": "World"}
+    assert status == MessageStatus.PENDING
+
+    async with mock_node_with_post_success as node:
+        amend_message, status = await node.create_post(
+            post_content={
+                "Hello": "World",
+                "Foo": "Bar",
+            },
+            post_type="amend",
+            ref=post_message.item_hash,
+            channel="TEST",
+        )
+
+    async with mock_node_with_post_success as node:
+        posts = (
+            await node.get_posts(
+                post_filter=PostFilter(
+                    hashes=[post_message.item_hash],
+                )
+            )
+        ).posts
+    assert posts[0].content == {"Hello": "World", "Foo": "Bar"}
diff --git a/tests/unit/test_node_get.py b/tests/unit/test_node_get.py
index 7c7c7678..732e5186 100644
--- a/tests/unit/test_node_get.py
+++ b/tests/unit/test_node_get.py
@@ -13,8 +13,9 @@
 
 from aleph.sdk.chains.ethereum import get_fallback_account
 from aleph.sdk.exceptions import MessageNotFoundError
+from aleph.sdk.models.message import MessageFilter
+from aleph.sdk.models.post import Post, PostFilter
 from aleph.sdk.node import MessageCache
-from aleph.sdk.node.post import message_to_post
 
 
 @pytest.mark.asyncio
@@ -22,11 
+23,6 @@ async def test_base(aleph_messages): # test add_many cache = MessageCache() cache.add(aleph_messages) - assert len(cache) == len(aleph_messages) - - item_hashes = [message.item_hash for message in aleph_messages] - cached_messages = cache.get(item_hashes) - assert len(cached_messages) == len(aleph_messages) for message in aleph_messages: assert cache[message.item_hash] == message @@ -63,35 +59,62 @@ async def test_iterate(self): @pytest.mark.asyncio async def test_addresses(self): - items = ( - await self.cache.get_messages(addresses=[self.messages[0].sender]) - ).messages - assert items[0] == self.messages[0] + assert ( + self.messages[0] + in ( + await self.cache.get_messages( + message_filter=MessageFilter( + addresses=[self.messages[0].sender], + ) + ) + ).messages + ) @pytest.mark.asyncio async def test_tags(self): assert ( - len((await self.cache.get_messages(tags=["thistagdoesnotexist"])).messages) + len( + ( + await self.cache.get_messages( + message_filter=MessageFilter(tags=["thistagdoesnotexist"]) + ) + ).messages + ) == 0 ) @pytest.mark.asyncio async def test_message_type(self): - assert (await self.cache.get_messages(message_type=MessageType.post)).messages[ - 0 - ] == self.messages[1] + assert ( + self.messages[1] + in ( + await self.cache.get_messages( + message_filter=MessageFilter(message_types=[MessageType.post]) + ) + ).messages + ) @pytest.mark.asyncio async def test_refs(self): assert ( - await self.cache.get_messages(refs=[self.messages[1].content.ref]) - ).messages[0] == self.messages[1] + self.messages[1] + in ( + await self.cache.get_messages( + message_filter=MessageFilter(refs=[self.messages[1].content.ref]) + ) + ).messages + ) @pytest.mark.asyncio async def test_hashes(self): assert ( - await self.cache.get_messages(hashes=[self.messages[0].item_hash]) - ).messages[0] == self.messages[0] + self.messages[0] + in ( + await self.cache.get_messages( + message_filter=MessageFilter(hashes=[self.messages[0].item_hash]) + ) + ).messages + ) @pytest.mark.asyncio async def test_pagination(self): @@ -100,26 +123,50 @@ async def test_pagination(self): @pytest.mark.asyncio async def test_content_types(self): assert ( - await self.cache.get_messages(content_types=[self.messages[1].content.type]) - ).messages[0] == self.messages[1] + self.messages[1] + in ( + await self.cache.get_messages( + message_filter=MessageFilter( + content_types=[self.messages[1].content.type] + ) + ) + ).messages + ) @pytest.mark.asyncio async def test_channels(self): assert ( - await self.cache.get_messages(channels=[self.messages[1].channel]) - ).messages[0] == self.messages[1] + self.messages[1] + in ( + await self.cache.get_messages( + message_filter=MessageFilter(channels=[self.messages[1].channel]) + ) + ).messages + ) @pytest.mark.asyncio async def test_chains(self): assert ( - await self.cache.get_messages(chains=[self.messages[1].chain]) - ).messages[0] == self.messages[1] + self.messages[1] + in ( + await self.cache.get_messages( + message_filter=MessageFilter(chains=[self.messages[1].chain]) + ) + ).messages + ) @pytest.mark.asyncio async def test_content_keys(self): assert ( - await self.cache.get_messages(content_keys=[self.messages[0].content.key]) - ).messages[0] == self.messages[0] + self.messages[0] + in ( + await self.cache.get_messages( + message_filter=MessageFilter( + content_keys=[self.messages[0].content.key] + ) + ) + ).messages + ) class TestPostQueries: @@ -137,33 +184,62 @@ def class_teardown(self): @pytest.mark.asyncio async def test_addresses(self): - items = 
(await self.cache.get_posts(addresses=[self.messages[1].sender])).posts - assert items[0] == message_to_post(self.messages[1]) + assert ( + Post.from_message(self.messages[1]) + in ( + await self.cache.get_posts( + post_filter=PostFilter(addresses=[self.messages[1].sender]) + ) + ).posts + ) @pytest.mark.asyncio async def test_tags(self): assert ( - len((await self.cache.get_posts(tags=["thistagdoesnotexist"])).posts) == 0 + len( + ( + await self.cache.get_posts( + post_filter=PostFilter(tags=["thistagdoesnotexist"]) + ) + ).posts + ) + == 0 ) @pytest.mark.asyncio async def test_types(self): assert ( - len((await self.cache.get_posts(types=["thistypedoesnotexist"])).posts) == 0 + len( + ( + await self.cache.get_posts( + post_filter=PostFilter(types=["thistypedoesnotexist"]) + ) + ).posts + ) + == 0 ) @pytest.mark.asyncio async def test_channels(self): - print(self.messages[1]) - assert (await self.cache.get_posts(channels=[self.messages[1].channel])).posts[ - 0 - ] == message_to_post(self.messages[1]) + assert ( + Post.from_message(self.messages[1]) + in ( + await self.cache.get_posts( + post_filter=PostFilter(channels=[self.messages[1].channel]) + ) + ).posts + ) @pytest.mark.asyncio async def test_chains(self): - assert (await self.cache.get_posts(chains=[self.messages[1].chain])).posts[ - 0 - ] == message_to_post(self.messages[1]) + assert ( + Post.from_message(self.messages[1]) + in ( + await self.cache.get_posts( + post_filter=PostFilter(chains=[self.messages[1].chain]) + ) + ).posts + ) @pytest.mark.asyncio diff --git a/tests/unit/test_synchronous_get.py b/tests/unit/test_synchronous_get.py index eee26dcf..0788a1ab 100644 --- a/tests/unit/test_synchronous_get.py +++ b/tests/unit/test_synchronous_get.py @@ -2,14 +2,16 @@ from aleph.sdk.client import AlephClient from aleph.sdk.conf import settings +from aleph.sdk.models.message import MessageFilter def test_get_post_messages(): with AlephClient(api_server=settings.API_HOST) as session: - # TODO: Remove deprecated message_type parameter after message_types changes on pyaleph are deployed response: MessagesResponse = session.get_messages( pagination=2, - message_type=MessageType.post, + message_filter=MessageFilter( + message_types=[MessageType.post], + ), ) messages = response.messages
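
For reference, a minimal sketch of how the refactored filter API introduced in
this series is meant to be called by SDK consumers (illustrative only, not part
of the patch; it mirrors the updated tests and assumes a reachable API server):

    import asyncio

    from aleph_message.models import MessageType

    from aleph.sdk.client import AlephClient
    from aleph.sdk.conf import settings
    from aleph.sdk.models.message import MessageFilter

    async def main():
        # All query constraints now travel through a single filter object
        # instead of a dozen keyword arguments.
        async with AlephClient(api_server=settings.API_HOST) as client:
            response = await client.get_messages(
                pagination=10,
                message_filter=MessageFilter(
                    message_types=[MessageType.post],
                    channels=["TEST"],
                ),
            )
            for message in response.messages:
                print(message.item_hash)

    asyncio.run(main())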