From 1e0d29e64da5f3e45e9ecc809f8cef7a6c49b085 Mon Sep 17 00:00:00 2001 From: amalcaraz Date: Tue, 24 Jun 2025 16:56:39 +0200 Subject: [PATCH 1/2] feat: Implement historical pricing for message cost recalculation --- ...1c06d0ade60c_calculate_costs_statically.py | 50 +-- src/aleph/services/pricing_utils.py | 114 +++++ src/aleph/toolkit/constants.py | 18 +- src/aleph/web/controllers/prices.py | 116 ++++- src/aleph/web/controllers/routes.py | 2 + tests/api/test_pricing_recalculation.py | 410 ++++++++++++++++++ tests/services/test_pricing_utils.py | 401 +++++++++++++++++ 7 files changed, 1058 insertions(+), 53 deletions(-) create mode 100644 src/aleph/services/pricing_utils.py create mode 100644 tests/api/test_pricing_recalculation.py create mode 100644 tests/services/test_pricing_utils.py diff --git a/deployment/migrations/versions/0033_1c06d0ade60c_calculate_costs_statically.py b/deployment/migrations/versions/0033_1c06d0ade60c_calculate_costs_statically.py index b1f36c93..ea44cab1 100644 --- a/deployment/migrations/versions/0033_1c06d0ade60c_calculate_costs_statically.py +++ b/deployment/migrations/versions/0033_1c06d0ade60c_calculate_costs_statically.py @@ -17,7 +17,8 @@ from aleph.db.accessors.cost import make_costs_upsert_query from aleph.db.accessors.messages import get_message_by_item_hash from aleph.services.cost import _is_confidential_vm, get_detailed_costs, CostComputableContent -from aleph.types.cost import ProductComputeUnit, ProductPrice, ProductPriceOptions, ProductPriceType, ProductPricing +from aleph.services.pricing_utils import build_default_pricing_model +from aleph.types.cost import ProductPriceType, ProductPricing from aleph.types.db_session import DbSession logger = logging.getLogger("alembic") @@ -30,48 +31,6 @@ depends_on = None -hardcoded_initial_price: Dict[ProductPriceType, ProductPricing] = { - ProductPriceType.PROGRAM: ProductPricing( - ProductPriceType.PROGRAM, - ProductPrice( - ProductPriceOptions("0.05", "0.000000977"), - ProductPriceOptions("200", "0.011") - ), - ProductComputeUnit(1, 2048, 2048) - ), - ProductPriceType.PROGRAM_PERSISTENT: ProductPricing( - ProductPriceType.PROGRAM_PERSISTENT, - ProductPrice( - ProductPriceOptions("0.05", "0.000000977"), - ProductPriceOptions("1000", "0.055") - ), - ProductComputeUnit(1, 20480, 2048) - ), - ProductPriceType.INSTANCE: ProductPricing( - ProductPriceType.INSTANCE, - ProductPrice( - ProductPriceOptions("0.05", "0.000000977"), - ProductPriceOptions("1000", "0.055") - ), - ProductComputeUnit(1, 20480, 2048) - ), - ProductPriceType.INSTANCE_CONFIDENTIAL: ProductPricing( - ProductPriceType.INSTANCE_CONFIDENTIAL, - ProductPrice( - ProductPriceOptions("0.05", "0.000000977"), - ProductPriceOptions("2000", "0.11") - ), - ProductComputeUnit(1, 20480, 2048) - ), - ProductPriceType.STORAGE: ProductPricing( - ProductPriceType.STORAGE, - ProductPrice( - ProductPriceOptions("0.333333333"), - ) - ), -} - - def _get_product_instance_type( content: InstanceContent ) -> ProductPriceType: @@ -112,12 +71,15 @@ def do_calculate_costs() -> None: logger.debug("INIT: CALCULATE COSTS FOR: %r", msg_item_hashes) + # Build the initial pricing model from DEFAULT_PRICE_AGGREGATE + initial_pricing_model = build_default_pricing_model() + for item_hash in msg_item_hashes: message = get_message_by_item_hash(session, item_hash) if message: content = message.parsed_content type = _get_product_price_type(content) - pricing = hardcoded_initial_price[type] + pricing = initial_pricing_model[type] costs = get_detailed_costs(session, content, 
message.item_hash, pricing) if len(costs) > 0: diff --git a/src/aleph/services/pricing_utils.py b/src/aleph/services/pricing_utils.py new file mode 100644 index 00000000..0e5f873b --- /dev/null +++ b/src/aleph/services/pricing_utils.py @@ -0,0 +1,114 @@ +""" +Utility functions for pricing model creation and management. +""" + +import datetime as dt +from typing import Dict, List, Union + +from aleph.db.accessors.aggregates import get_aggregate_elements, merge_aggregate_elements +from aleph.db.models import AggregateElementDb +from aleph.toolkit.constants import ( + DEFAULT_PRICE_AGGREGATE, + PRICE_AGGREGATE_KEY, + PRICE_AGGREGATE_OWNER, +) +from aleph.types.cost import ProductPriceType, ProductPricing +from aleph.types.db_session import DbSession + + +def build_pricing_model_from_aggregate(aggregate_content: Dict[Union[ProductPriceType, str], dict]) -> Dict[ProductPriceType, ProductPricing]: + """ + Build a complete pricing model from an aggregate content dictionary. + + This function converts the DEFAULT_PRICE_AGGREGATE format or any pricing aggregate + content into a dictionary of ProductPricing objects that can be used by the cost + calculation functions. + + Args: + aggregate_content: Dictionary containing pricing information with ProductPriceType as keys + + Returns: + Dictionary mapping ProductPriceType to ProductPricing objects + """ + pricing_model: Dict[ProductPriceType, ProductPricing] = {} + + for price_type, pricing_data in aggregate_content.items(): + try: + price_type = ProductPriceType(price_type) + pricing_model[price_type] = ProductPricing.from_aggregate( + price_type, aggregate_content + ) + except (KeyError, ValueError) as e: + # Log the error but continue processing other price types + import logging + logger = logging.getLogger(__name__) + logger.warning(f"Failed to parse pricing for {price_type}: {e}") + + return pricing_model + + +def build_default_pricing_model() -> Dict[ProductPriceType, ProductPricing]: + """ + Build the default pricing model from DEFAULT_PRICE_AGGREGATE constant. + + Returns: + Dictionary mapping ProductPriceType to ProductPricing objects + """ + return build_pricing_model_from_aggregate(DEFAULT_PRICE_AGGREGATE) + + +def get_pricing_aggregate_history(session: DbSession) -> List[AggregateElementDb]: + """ + Get all pricing aggregate updates in chronological order. + + Args: + session: Database session + + Returns: + List of AggregateElementDb objects ordered by creation_datetime + """ + aggregate_elements = get_aggregate_elements( + session=session, + owner=PRICE_AGGREGATE_OWNER, + key=PRICE_AGGREGATE_KEY + ) + return list(aggregate_elements) + + +def get_pricing_timeline(session: DbSession) -> List[tuple[dt.datetime, Dict[ProductPriceType, ProductPricing]]]: + """ + Get the complete pricing timeline with timestamps and pricing models. + + This function returns a chronologically ordered list of pricing changes, + useful for processing messages in chronological order and applying the + correct pricing at each point in time. + + This properly merges aggregate elements up to each point in time to create + the cumulative pricing state, similar to how _update_aggregate works. 
+ + Args: + session: Database session + + Returns: + List of tuples containing (timestamp, pricing_model) + """ + pricing_elements = get_pricing_aggregate_history(session) + + timeline = [] + + # Add default pricing as the initial state + timeline.append((dt.datetime.min.replace(tzinfo=dt.timezone.utc), build_default_pricing_model())) + + # Build cumulative pricing models by merging elements up to each timestamp + elements_so_far = [] + for element in pricing_elements: + elements_so_far.append(element) + + # Merge all elements up to this point to get the cumulative state + merged_content = merge_aggregate_elements(elements_so_far) + + # Build pricing model from the merged content + pricing_model = build_pricing_model_from_aggregate(merged_content) + timeline.append((element.creation_datetime, pricing_model)) + + return timeline \ No newline at end of file diff --git a/src/aleph/toolkit/constants.py b/src/aleph/toolkit/constants.py index d33ccc1a..1517ec2c 100644 --- a/src/aleph/toolkit/constants.py +++ b/src/aleph/toolkit/constants.py @@ -5,11 +5,13 @@ MINUTE = 60 HOUR = 60 * MINUTE +from aleph.types.cost import ProductPriceType + PRICE_AGGREGATE_OWNER = "0xFba561a84A537fCaa567bb7A2257e7142701ae2A" PRICE_AGGREGATE_KEY = "pricing" PRICE_PRECISION = 18 DEFAULT_PRICE_AGGREGATE = { - "program": { + ProductPriceType.PROGRAM: { "price": { "storage": {"payg": "0.000000977", "holding": "0.05"}, "compute_unit": {"payg": "0.011", "holding": "200"}, @@ -28,8 +30,8 @@ "memory_mib": 2048, }, }, - "storage": {"price": {"storage": {"holding": "0.333333333"}}}, - "instance": { + ProductPriceType.STORAGE: {"price": {"storage": {"holding": "0.333333333"}}}, + ProductPriceType.INSTANCE: { "price": { "storage": {"payg": "0.000000977", "holding": "0.05"}, "compute_unit": {"payg": "0.055", "holding": "1000"}, @@ -48,8 +50,8 @@ "memory_mib": 2048, }, }, - "web3_hosting": {"price": {"fixed": 50, "storage": {"holding": "0.333333333"}}}, - "program_persistent": { + ProductPriceType.WEB3_HOSTING: {"price": {"fixed": 50, "storage": {"holding": "0.333333333"}}}, + ProductPriceType.PROGRAM_PERSISTENT: { "price": { "storage": {"payg": "0.000000977", "holding": "0.05"}, "compute_unit": {"payg": "0.055", "holding": "1000"}, @@ -68,7 +70,7 @@ "memory_mib": 2048, }, }, - "instance_gpu_premium": { + ProductPriceType.INSTANCE_GPU_PREMIUM: { "price": { "storage": {"payg": "0.000000977"}, "compute_unit": {"payg": "0.56"}, @@ -93,7 +95,7 @@ "memory_mib": 6144, }, }, - "instance_confidential": { + ProductPriceType.INSTANCE_CONFIDENTIAL: { "price": { "storage": {"payg": "0.000000977", "holding": "0.05"}, "compute_unit": {"payg": "0.11", "holding": "2000"}, @@ -112,7 +114,7 @@ "memory_mib": 2048, }, }, - "instance_gpu_standard": { + ProductPriceType.INSTANCE_GPU_STANDARD: { "price": { "storage": {"payg": "0.000000977"}, "compute_unit": {"payg": "0.28"}, diff --git a/src/aleph/web/controllers/prices.py b/src/aleph/web/controllers/prices.py index 48fc9df3..fe7bcfe6 100644 --- a/src/aleph/web/controllers/prices.py +++ b/src/aleph/web/controllers/prices.py @@ -1,15 +1,17 @@ import logging from dataclasses import dataclass from decimal import Decimal -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from aiohttp import web from aiohttp.web_exceptions import HTTPException from aleph_message.models import ExecutableContent, ItemHash, MessageType from dataclasses_json import DataClassJsonMixin from pydantic import BaseModel, Field +from sqlalchemy import select import aleph.toolkit.json as aleph_json 
+from aleph.db.accessors.cost import delete_costs_for_message, make_costs_upsert_query from aleph.db.accessors.messages import get_message_by_item_hash, get_message_status from aleph.db.models import MessageDb from aleph.schemas.api.costs import EstimatedCostsResponse @@ -18,10 +20,14 @@ validate_cost_estimation_message_dict, ) from aleph.services.cost import ( + _get_product_price_type, + _get_settings, + get_detailed_costs, get_payment_type, get_total_and_detailed_costs, get_total_and_detailed_costs_from_db, ) +from aleph.services.pricing_utils import get_pricing_timeline from aleph.toolkit.costs import format_cost_str from aleph.types.db_session import DbSession from aleph.types.message_status import MessageStatus @@ -164,3 +170,111 @@ async def message_price_estimate(request: web.Request): response = EstimatedCostsResponse.model_validate(model) return web.json_response(text=aleph_json.dumps(response).decode("utf-8")) + + +async def recalculate_message_costs(request: web.Request): + """Force recalculation of message costs in chronological order with historical pricing. + + This endpoint will: + 1. Get all messages that need cost recalculation (if item_hash provided, just that message) + 2. Get the pricing timeline to track price changes over time + 3. Sort messages chronologically (oldest first) + 4. For each message, use the pricing model that was active when the message was created + 5. Delete existing cost entries and recalculate with historical pricing + 6. Store the new cost calculations + """ + + session_factory = get_session_factory_from_request(request) + + # Check if a specific message hash was provided + item_hash_param = request.match_info.get("item_hash") + + with session_factory() as session: + messages_to_recalculate: List[MessageDb] = [] + + if item_hash_param: + # Recalculate costs for a specific message + try: + message = await get_executable_message(session, item_hash_param) + messages_to_recalculate = [message] + except HTTPException: + raise + else: + # Recalculate costs for all executable messages, ordered by time (oldest first) + select_stmt = ( + select(MessageDb) + .where(MessageDb.type.in_([MessageType.instance, MessageType.program, MessageType.store])) + .order_by(MessageDb.time.asc()) + ) + result = session.execute(select_stmt) + messages_to_recalculate = result.scalars().all() + + if not messages_to_recalculate: + return web.json_response( + {"message": "No messages found for cost recalculation", "recalculated_count": 0} + ) + + # Get the pricing timeline to track price changes over time + pricing_timeline = get_pricing_timeline(session) + LOGGER.info(f"Found {len(pricing_timeline)} pricing changes in timeline") + + recalculated_count = 0 + errors = [] + current_pricing_model = None + current_pricing_index = 0 + + for message in messages_to_recalculate: + try: + # Find the applicable pricing model for this message's timestamp + while (current_pricing_index < len(pricing_timeline) - 1 and + pricing_timeline[current_pricing_index + 1][0] <= message.time): + current_pricing_index += 1 + + current_pricing_model = pricing_timeline[current_pricing_index][1] + pricing_timestamp = pricing_timeline[current_pricing_index][0] + + LOGGER.debug(f"Message {message.item_hash} at {message.time} using pricing from {pricing_timestamp}") + + # Delete existing cost entries for this message + delete_costs_for_message(session, message.item_hash) + + # Get the message content and determine product type + content: ExecutableContent = message.parsed_content + product_type = 
_get_product_price_type(content, None, current_pricing_model) + + # Get the pricing for this specific product type + if product_type not in current_pricing_model: + LOGGER.warning(f"Product type {product_type} not found in pricing model for message {message.item_hash}") + continue + + pricing = current_pricing_model[product_type] + + # Calculate new costs using the historical pricing model + new_costs = get_detailed_costs(session, content, message.item_hash, pricing) + + if new_costs: + # Store the new cost calculations + upsert_stmt = make_costs_upsert_query(new_costs) + session.execute(upsert_stmt) + + recalculated_count += 1 + + except Exception as e: + error_msg = f"Failed to recalculate costs for message {message.item_hash}: {str(e)}" + LOGGER.error(error_msg) + errors.append({"item_hash": message.item_hash, "error": str(e)}) + + # Commit all changes + session.commit() + + response_data = { + "message": "Cost recalculation completed with historical pricing", + "recalculated_count": recalculated_count, + "total_messages": len(messages_to_recalculate), + "pricing_changes_found": len(pricing_timeline) + } + + if errors: + response_data["errors"] = errors + + return web.json_response(response_data) diff --git a/src/aleph/web/controllers/routes.py b/src/aleph/web/controllers/routes.py index b4678fdf..297bf940 100644 --- a/src/aleph/web/controllers/routes.py +++ b/src/aleph/web/controllers/routes.py @@ -67,6 +67,8 @@ def register_routes(app: web.Application): app.router.add_get("/api/v0/price/{item_hash}", prices.message_price) app.router.add_post("/api/v0/price/estimate", prices.message_price_estimate) + app.router.add_post("/api/v0/price/recalculate", prices.recalculate_message_costs) + app.router.add_post("/api/v0/price/{item_hash}/recalculate", prices.recalculate_message_costs) app.router.add_get("/api/v0/addresses/stats.json", accounts.addresses_stats_view) app.router.add_get( diff --git a/tests/api/test_pricing_recalculation.py b/tests/api/test_pricing_recalculation.py new file mode 100644 index 00000000..44e4d1c0 --- /dev/null +++ b/tests/api/test_pricing_recalculation.py @@ -0,0 +1,410 @@ +import datetime as dt +import json +from decimal import Decimal +from unittest.mock import patch + +import pytest +from aiohttp import web +from aleph_message.models import MessageType + +from aleph.db.models import MessageDb +from aleph.db.models.account_costs import AccountCostsDb +from aleph.db.models.aggregates import AggregateElementDb +from aleph.toolkit.constants import ( + PRICE_AGGREGATE_KEY, + PRICE_AGGREGATE_OWNER, +) +from aleph.types.cost import ProductPriceType +from aleph.types.message_status import MessageStatus +from aleph.web.controllers.prices import recalculate_message_costs + + +@pytest.fixture +def sample_messages(session_factory): + """Create sample messages for testing cost recalculation.""" + base_time = dt.datetime(2024, 1, 1, 10, 0, 0, tzinfo=dt.timezone.utc) + + # Create sample instance message + instance_message = MessageDb( + item_hash="instance_msg_1", + type=MessageType.instance, + chain="ETH", + sender="0xTest1", + item_type="inline", + content={ + "time": (base_time + dt.timedelta(hours=1)).timestamp(), + "rootfs": { + "parent": {"ref": "test_ref", "use_latest": True}, + "size_mib": 20480, + "persistence": "host", + }, + "address": "0xTest1", + "volumes": [], + "metadata": {"name": "Test Instance"}, + "resources": {"vcpus": 1, "memory": 2048, "seconds": 30}, + "allow_amend": False, + "environment": {"internet": True, "aleph_api": True}, + }, + time=base_time + 
dt.timedelta(hours=1), + size=1024, + ) + + # Create sample program message + program_message = MessageDb( + item_hash="program_msg_1", + type=MessageType.program, + chain="ETH", + sender="0xTest2", + item_type="inline", + content={ + "time": (base_time + dt.timedelta(hours=2)).timestamp(), + "on": {"http": True, "persistent": False}, + "code": {"ref": "code_ref", "encoding": "zip", "entrypoint": "main:app", "use_latest": True}, + "runtime": {"ref": "runtime_ref", "use_latest": True}, + "address": "0xTest2", + "resources": {"vcpus": 1, "memory": 128, "seconds": 30}, + "allow_amend": False, + "environment": {"internet": True, "aleph_api": True}, + }, + time=base_time + dt.timedelta(hours=2), + size=512, + ) + + # Create sample store message + store_message = MessageDb( + item_hash="store_msg_1", + type=MessageType.store, + chain="ETH", + sender="0xTest3", + item_type="inline", + content={ + "time": (base_time + dt.timedelta(hours=3)).timestamp(), + "item_type": "storage", + "item_hash": "stored_file_hash", + "address": "0xTest3", + }, + time=base_time + dt.timedelta(hours=3), + size=2048, + ) + + with session_factory() as session: + session.add(instance_message) + session.add(program_message) + session.add(store_message) + session.commit() + session.refresh(instance_message) + session.refresh(program_message) + session.refresh(store_message) + + return [instance_message, program_message, store_message] + + +@pytest.fixture +def pricing_updates_with_timeline(session_factory): + """Create pricing updates that form a timeline for testing.""" + base_time = dt.datetime(2024, 1, 1, 9, 0, 0, tzinfo=dt.timezone.utc) + + # First pricing update - before any messages + element1 = AggregateElementDb( + item_hash="pricing_1", + key=PRICE_AGGREGATE_KEY, + owner=PRICE_AGGREGATE_OWNER, + content={ + ProductPriceType.STORAGE: { + "price": {"storage": {"holding": "0.1"}} + }, + ProductPriceType.INSTANCE: { + "price": { + "storage": {"holding": "0.05"}, + "compute_unit": {"holding": "500"}, + }, + "compute_unit": {"vcpus": 1, "disk_mib": 20480, "memory_mib": 2048}, + } + }, + creation_datetime=base_time + dt.timedelta(minutes=30) + ) + + # Second pricing update - between instance and program messages + element2 = AggregateElementDb( + item_hash="pricing_2", + key=PRICE_AGGREGATE_KEY, + owner=PRICE_AGGREGATE_OWNER, + content={ + ProductPriceType.PROGRAM: { + "price": { + "storage": {"holding": "0.03"}, + "compute_unit": {"holding": "150"}, + }, + "compute_unit": {"vcpus": 1, "disk_mib": 2048, "memory_mib": 2048}, + } + }, + creation_datetime=base_time + dt.timedelta(hours=1, minutes=30) + ) + + # Third pricing update - after program but before store message + element3 = AggregateElementDb( + item_hash="pricing_3", + key=PRICE_AGGREGATE_KEY, + owner=PRICE_AGGREGATE_OWNER, + content={ + ProductPriceType.STORAGE: { + "price": {"storage": {"holding": "0.2"}} # Updated storage price + } + }, + creation_datetime=base_time + dt.timedelta(hours=2, minutes=30) + ) + + with session_factory() as session: + session.add(element1) + session.add(element2) + session.add(element3) + session.commit() + session.refresh(element1) + session.refresh(element2) + session.refresh(element3) + + return [element1, element2, element3] + + +@pytest.fixture +def existing_costs(session_factory, sample_messages): + """Create some existing cost entries to test deletion and recalculation.""" + costs = [] + + for message in sample_messages: + cost = AccountCostsDb( + owner=message.sender, + item_hash=message.item_hash, + type="EXECUTION", + 
name="old_cost", + payment_type="hold", + cost_hold=Decimal("999.99"), # Old/incorrect cost + cost_stream=Decimal("0.01"), + ) + costs.append(cost) + + with session_factory() as session: + for cost in costs: + session.add(cost) + session.commit() + + return costs + + +class TestRecalculateMessageCosts: + """Tests for the message cost recalculation endpoint.""" + + @pytest.fixture + def mock_request_factory(self, session_factory): + """Factory to create mock requests.""" + def _create_mock_request(match_info=None): + request = web.Request.__new__(web.Request) + request._match_info = match_info or {} + + # Mock the session factory getter + def get_session_factory(): + return session_factory + + request._session_factory = get_session_factory + return request + + return _create_mock_request + + @patch('aleph.web.controllers.prices.get_session_factory_from_request') + async def test_recalculate_all_messages_empty_db(self, mock_get_session, session_factory, mock_request_factory): + """Test recalculation when no messages exist.""" + mock_get_session.return_value = session_factory + request = mock_request_factory() + + response = await recalculate_message_costs(request) + + assert response.status == 200 + response_data = json.loads(response.text) + assert response_data["recalculated_count"] == 0 + assert response_data["total_messages"] == 0 + assert "No messages found" in response_data["message"] + + @patch('aleph.web.controllers.prices.get_session_factory_from_request') + @patch('aleph.web.controllers.prices.get_executable_message') + async def test_recalculate_specific_message(self, mock_get_executable, mock_get_session, session_factory, sample_messages, mock_request_factory): + """Test recalculation of a specific message.""" + mock_get_session.return_value = session_factory + mock_get_executable.return_value = sample_messages[0] # Return first message + + request = mock_request_factory({"item_hash": "instance_msg_1"}) + + with patch('aleph.web.controllers.prices.get_detailed_costs') as mock_get_costs: + mock_get_costs.return_value = [] # Mock empty costs + + response = await recalculate_message_costs(request) + + assert response.status == 200 + response_data = json.loads(response.text) + assert response_data["recalculated_count"] == 1 + assert response_data["total_messages"] == 1 + assert "historical pricing" in response_data["message"] + + # Should have called get_detailed_costs once + assert mock_get_costs.call_count == 1 + + @patch('aleph.web.controllers.prices.get_session_factory_from_request') + async def test_recalculate_all_messages_with_timeline(self, mock_get_session, session_factory, sample_messages, pricing_updates_with_timeline, existing_costs, mock_request_factory): + """Test recalculation of all messages with pricing timeline.""" + mock_get_session.return_value = session_factory + + request = mock_request_factory() + + with patch('aleph.web.controllers.prices.get_detailed_costs') as mock_get_costs: + mock_get_costs.return_value = [] # Mock empty costs + + response = await recalculate_message_costs(request) + + assert response.status == 200 + response_data = json.loads(response.text) + assert response_data["recalculated_count"] == 3 + assert response_data["total_messages"] == 3 + assert response_data["pricing_changes_found"] == 4 # Default + 3 updates + + # Should have called get_detailed_costs for each message + assert mock_get_costs.call_count == 3 + + # Verify old costs were deleted + with session_factory() as session: + remaining_costs = session.query(AccountCostsDb).all() + 
assert len(remaining_costs) == 0 # All old costs should be deleted + + @patch('aleph.web.controllers.prices.get_session_factory_from_request') + async def test_recalculate_with_pricing_timeline_application(self, mock_get_session, session_factory, sample_messages, pricing_updates_with_timeline, mock_request_factory): + """Test that the correct pricing model is applied based on message timestamps.""" + mock_get_session.return_value = session_factory + + request = mock_request_factory() + + pricing_calls = [] + + def mock_get_costs(session, content, item_hash, pricing): + # Capture the pricing object used for each call + pricing_calls.append((item_hash, pricing.type if pricing else None)) + return [] + + with patch('aleph.web.controllers.prices.get_detailed_costs', side_effect=mock_get_costs): + response = await recalculate_message_costs(request) + + assert response.status == 200 + + # Should have made calls for all 3 messages + assert len(pricing_calls) == 3 + + # Verify the correct pricing types were used (based on message content and timeline) + item_hashes = [call[0] for call in pricing_calls] + assert "instance_msg_1" in item_hashes + assert "program_msg_1" in item_hashes + assert "store_msg_1" in item_hashes + + @patch('aleph.web.controllers.prices.get_session_factory_from_request') + async def test_recalculate_with_errors(self, mock_get_session, session_factory, sample_messages, mock_request_factory): + """Test recalculation handling of errors.""" + mock_get_session.return_value = session_factory + + request = mock_request_factory() + + def mock_get_costs_with_error(session, content, item_hash, pricing): + if item_hash == "program_msg_1": + raise ValueError("Test error for program message") + return [] + + with patch('aleph.web.controllers.prices.get_detailed_costs', side_effect=mock_get_costs_with_error): + response = await recalculate_message_costs(request) + + assert response.status == 200 + response_data = json.loads(response.text) + + # Should have processed 2 successfully, 1 with error + assert response_data["recalculated_count"] == 2 + assert response_data["total_messages"] == 3 + assert "errors" in response_data + assert len(response_data["errors"]) == 1 + assert response_data["errors"][0]["item_hash"] == "program_msg_1" + assert "Test error" in response_data["errors"][0]["error"] + + @patch('aleph.web.controllers.prices.get_session_factory_from_request') + @patch('aleph.web.controllers.prices.get_executable_message') + async def test_recalculate_specific_message_not_found(self, mock_get_executable, mock_get_session, session_factory, mock_request_factory): + """Test recalculation of a specific message that doesn't exist.""" + mock_get_session.return_value = session_factory + mock_get_executable.side_effect = web.HTTPNotFound(body="Message not found") + + request = mock_request_factory({"item_hash": "nonexistent_hash"}) + + with pytest.raises(web.HTTPNotFound): + await recalculate_message_costs(request) + + @patch('aleph.web.controllers.prices.get_session_factory_from_request') + async def test_chronological_processing_order(self, mock_get_session, session_factory, sample_messages, mock_request_factory): + """Test that messages are processed in chronological order.""" + mock_get_session.return_value = session_factory + + request = mock_request_factory() + + processed_order = [] + + def mock_get_costs(session, content, item_hash, pricing): + processed_order.append(item_hash) + return [] + + with patch('aleph.web.controllers.prices.get_detailed_costs', side_effect=mock_get_costs): 
+ response = await recalculate_message_costs(request) + + assert response.status == 200 + + # Should have processed in chronological order based on message.time + expected_order = ["instance_msg_1", "program_msg_1", "store_msg_1"] + assert processed_order == expected_order + + +class TestPricingTimelineIntegration: + """Integration tests for the complete pricing timeline feature.""" + + @pytest.fixture + def mock_request_factory(self, session_factory): + """Factory to create mock requests.""" + def _create_mock_request(match_info=None): + request = web.Request.__new__(web.Request) + request._match_info = match_info or {} + + # Mock the session factory getter + def get_session_factory(): + return session_factory + + request._session_factory = get_session_factory + return request + + return _create_mock_request + + @patch('aleph.web.controllers.prices.get_session_factory_from_request') + async def test_end_to_end_historical_pricing(self, mock_get_session, session_factory, sample_messages, pricing_updates_with_timeline, mock_request_factory): + """End-to-end test of historical pricing application.""" + mock_get_session.return_value = session_factory + + request = mock_request_factory() + + # Track which pricing models are used for each message + pricing_usage = {} + + def mock_get_costs(session, content, item_hash, pricing): + if pricing and hasattr(pricing, 'price'): + if hasattr(pricing.price, 'storage') and hasattr(pricing.price.storage, 'holding'): + pricing_usage[item_hash] = float(pricing.price.storage.holding) + return [] + + with patch('aleph.web.controllers.prices.get_detailed_costs', side_effect=mock_get_costs): + response = await recalculate_message_costs(request) + + assert response.status == 200 + response_data = json.loads(response.text) + assert response_data["recalculated_count"] == 3 + + # Verify that different pricing was applied based on timeline + # The exact values depend on the pricing timeline and merge logic, + # but we can verify that historical pricing was considered + assert len(pricing_usage) > 0 \ No newline at end of file diff --git a/tests/services/test_pricing_utils.py b/tests/services/test_pricing_utils.py new file mode 100644 index 00000000..1cade1ff --- /dev/null +++ b/tests/services/test_pricing_utils.py @@ -0,0 +1,401 @@ +import datetime as dt +from decimal import Decimal +from unittest.mock import patch + +import pytest +from aleph_message.models import InstanceContent, ProgramContent, StoreContent + +from aleph.db.models.aggregates import AggregateElementDb +from aleph.services.pricing_utils import ( + build_default_pricing_model, + build_pricing_model_from_aggregate, + get_pricing_timeline, + get_pricing_aggregate_history, +) +from aleph.toolkit.constants import ( + PRICE_AGGREGATE_KEY, + PRICE_AGGREGATE_OWNER, +) +from aleph.types.cost import ProductPriceType, ProductPricing +from aleph.types.db_session import DbSessionFactory + + +@pytest.fixture +def sample_pricing_aggregate_content(): + """Sample pricing aggregate content with ProductPriceType keys.""" + return { + ProductPriceType.STORAGE: { + "price": {"storage": {"holding": "0.5"}} + }, + ProductPriceType.PROGRAM: { + "price": { + "storage": {"payg": "0.000001", "holding": "0.1"}, + "compute_unit": {"payg": "0.02", "holding": "300"}, + }, + "compute_unit": { + "vcpus": 1, + "disk_mib": 2048, + "memory_mib": 2048, + }, + }, + ProductPriceType.INSTANCE: { + "price": { + "storage": {"payg": "0.000001", "holding": "0.1"}, + "compute_unit": {"payg": "0.1", "holding": "1500"}, + }, + "compute_unit": { 
+ "vcpus": 1, + "disk_mib": 20480, + "memory_mib": 2048, + }, + }, + } + + +@pytest.fixture +def pricing_aggregate_elements(session_factory): + """Create sample pricing aggregate elements for timeline testing.""" + base_time = dt.datetime(2024, 1, 1, 12, 0, 0, tzinfo=dt.timezone.utc) + + # First pricing update - only storage pricing + element1 = AggregateElementDb( + item_hash="pricing_update_1", + key=PRICE_AGGREGATE_KEY, + owner=PRICE_AGGREGATE_OWNER, + content={ + ProductPriceType.STORAGE: { + "price": {"storage": {"holding": "0.2"}} + } + }, + creation_datetime=base_time + dt.timedelta(hours=1) + ) + + # Second pricing update - adds program pricing + element2 = AggregateElementDb( + item_hash="pricing_update_2", + key=PRICE_AGGREGATE_KEY, + owner=PRICE_AGGREGATE_OWNER, + content={ + ProductPriceType.PROGRAM: { + "price": { + "storage": {"payg": "0.000001", "holding": "0.08"}, + "compute_unit": {"payg": "0.015", "holding": "250"}, + }, + "compute_unit": { + "vcpus": 1, + "disk_mib": 2048, + "memory_mib": 2048, + }, + } + }, + creation_datetime=base_time + dt.timedelta(hours=2) + ) + + # Third pricing update - updates storage and adds instance pricing + element3 = AggregateElementDb( + item_hash="pricing_update_3", + key=PRICE_AGGREGATE_KEY, + owner=PRICE_AGGREGATE_OWNER, + content={ + ProductPriceType.STORAGE: { + "price": {"storage": {"holding": "0.3"}} + }, + ProductPriceType.INSTANCE: { + "price": { + "storage": {"payg": "0.000001", "holding": "0.05"}, + "compute_unit": {"payg": "0.06", "holding": "1200"}, + }, + "compute_unit": { + "vcpus": 1, + "disk_mib": 20480, + "memory_mib": 2048, + }, + } + }, + creation_datetime=base_time + dt.timedelta(hours=3) + ) + + with session_factory() as session: + session.add(element1) + session.add(element2) + session.add(element3) + session.commit() + session.refresh(element1) + session.refresh(element2) + session.refresh(element3) + + return [element1, element2, element3] + + +class TestBuildPricingModelFromAggregate: + """Tests for building pricing models from aggregate content.""" + + def test_build_pricing_model_from_aggregate(self, sample_pricing_aggregate_content): + """Test building pricing model from aggregate content.""" + pricing_model = build_pricing_model_from_aggregate(sample_pricing_aggregate_content) + + # Check that we got ProductPricing objects for each type + assert len(pricing_model) == 3 + assert ProductPriceType.STORAGE in pricing_model + assert ProductPriceType.PROGRAM in pricing_model + assert ProductPriceType.INSTANCE in pricing_model + + # Check that each is a ProductPricing object + for price_type, pricing in pricing_model.items(): + assert isinstance(pricing, ProductPricing) + assert pricing.type == price_type + + def test_build_pricing_model_with_missing_types(self): + """Test building pricing model with some missing product types.""" + partial_content = { + ProductPriceType.STORAGE: { + "price": {"storage": {"holding": "0.5"}} + } + } + + pricing_model = build_pricing_model_from_aggregate(partial_content) + + assert len(pricing_model) == 1 + assert ProductPriceType.STORAGE in pricing_model + assert ProductPriceType.PROGRAM not in pricing_model + + def test_build_pricing_model_with_invalid_data(self): + """Test building pricing model with invalid pricing data.""" + invalid_content = { + ProductPriceType.STORAGE: { + "invalid": "data" # Missing required "price" key + } + } + + # Should handle the error gracefully and return empty model + pricing_model = build_pricing_model_from_aggregate(invalid_content) + assert 
len(pricing_model) == 0 + + +class TestBuildDefaultPricingModel: + """Tests for building the default pricing model.""" + + def test_build_default_pricing_model(self): + """Test building the default pricing model from constants.""" + pricing_model = build_default_pricing_model() + + # Should contain all the expected product types from DEFAULT_PRICE_AGGREGATE + expected_types = [ + ProductPriceType.PROGRAM, + ProductPriceType.STORAGE, + ProductPriceType.INSTANCE, + ProductPriceType.PROGRAM_PERSISTENT, + ProductPriceType.INSTANCE_GPU_PREMIUM, + ProductPriceType.INSTANCE_CONFIDENTIAL, + ProductPriceType.INSTANCE_GPU_STANDARD, + ProductPriceType.WEB3_HOSTING, + ] + + for price_type in expected_types: + assert price_type in pricing_model + assert isinstance(pricing_model[price_type], ProductPricing) + + +class TestGetPricingAggregateHistory: + """Tests for retrieving pricing aggregate history.""" + + def test_get_pricing_aggregate_history_empty(self, session_factory): + """Test getting pricing history when no elements exist.""" + with session_factory() as session: + history = get_pricing_aggregate_history(session) + assert len(history) == 0 + + def test_get_pricing_aggregate_history_with_elements(self, session_factory, pricing_aggregate_elements): + """Test getting pricing history with existing elements.""" + with session_factory() as session: + history = get_pricing_aggregate_history(session) + + assert len(history) == 3 + + # Should be ordered chronologically + assert history[0].creation_datetime < history[1].creation_datetime + assert history[1].creation_datetime < history[2].creation_datetime + + # Check content + assert ProductPriceType.STORAGE in history[0].content + assert ProductPriceType.PROGRAM in history[1].content + assert ProductPriceType.STORAGE in history[2].content + assert ProductPriceType.INSTANCE in history[2].content + + +class TestGetPricingTimeline: + """Tests for getting the pricing timeline.""" + + def test_get_pricing_timeline_empty(self, session_factory): + """Test getting pricing timeline when no aggregate elements exist.""" + with session_factory() as session: + timeline = get_pricing_timeline(session) + + # Should have at least the default pricing + assert len(timeline) == 1 + timestamp, pricing_model = timeline[0] + + # Should use minimum datetime for default pricing + assert timestamp == dt.datetime.min.replace(tzinfo=dt.timezone.utc) + assert isinstance(pricing_model, dict) + assert ProductPriceType.STORAGE in pricing_model + + def test_get_pricing_timeline_with_elements(self, session_factory, pricing_aggregate_elements): + """Test getting pricing timeline with aggregate elements.""" + with session_factory() as session: + timeline = get_pricing_timeline(session) + + # Should have default + 3 pricing updates + assert len(timeline) == 4 + + # Check chronological order + for i in range(len(timeline) - 1): + assert timeline[i][0] <= timeline[i + 1][0] + + # Check content evolution + default_timestamp, default_model = timeline[0] + first_timestamp, first_model = timeline[1] + second_timestamp, second_model = timeline[2] + third_timestamp, third_model = timeline[3] + + # First update: only storage pricing + assert ProductPriceType.STORAGE in first_model + storage_pricing_1 = first_model[ProductPriceType.STORAGE] + assert storage_pricing_1.price.storage.holding == Decimal("0.2") + + # Second update: storage + program pricing (cumulative) + assert ProductPriceType.STORAGE in second_model + assert ProductPriceType.PROGRAM in second_model + storage_pricing_2 = 
second_model[ProductPriceType.STORAGE] + assert storage_pricing_2.price.storage.holding == Decimal("0.2") # Still from first update + + # Third update: updated storage + program + instance pricing (cumulative) + assert ProductPriceType.STORAGE in third_model + assert ProductPriceType.PROGRAM in third_model + assert ProductPriceType.INSTANCE in third_model + storage_pricing_3 = third_model[ProductPriceType.STORAGE] + assert storage_pricing_3.price.storage.holding == Decimal("0.3") # Updated value + + def test_pricing_timeline_cumulative_merging(self, session_factory): + """Test that pricing timeline properly merges cumulative changes.""" + base_time = dt.datetime(2024, 1, 1, 12, 0, 0, tzinfo=dt.timezone.utc) + + # Create elements that update different parts of pricing + element1 = AggregateElementDb( + item_hash="test_1", + key=PRICE_AGGREGATE_KEY, + owner=PRICE_AGGREGATE_OWNER, + content={ + ProductPriceType.STORAGE: { + "price": {"storage": {"holding": "1.0"}} + }, + ProductPriceType.PROGRAM: { + "price": { + "storage": {"holding": "0.1"}, + "compute_unit": {"holding": "100"}, + }, + "compute_unit": {"vcpus": 1, "disk_mib": 1024, "memory_mib": 1024}, + } + }, + creation_datetime=base_time + dt.timedelta(hours=1) + ) + + # Second element only updates storage, should preserve program settings + element2 = AggregateElementDb( + item_hash="test_2", + key=PRICE_AGGREGATE_KEY, + owner=PRICE_AGGREGATE_OWNER, + content={ + ProductPriceType.STORAGE: { + "price": {"storage": {"holding": "2.0"}} # Updated price + } + }, + creation_datetime=base_time + dt.timedelta(hours=2) + ) + + with session_factory() as session: + session.add(element1) + session.add(element2) + session.commit() + + timeline = get_pricing_timeline(session) + + # Should have default + 2 updates + assert len(timeline) == 3 + + # Check final state has both storage and program, with updated storage price + final_timestamp, final_model = timeline[2] + + assert ProductPriceType.STORAGE in final_model + assert ProductPriceType.PROGRAM in final_model + + # Storage should have updated price + storage_pricing = final_model[ProductPriceType.STORAGE] + assert storage_pricing.price.storage.holding == Decimal("2.0") + + # Program should still have original settings + program_pricing = final_model[ProductPriceType.PROGRAM] + assert program_pricing.price.storage.holding == Decimal("0.1") + assert program_pricing.price.compute_unit.holding == Decimal("100") + + +class TestPricingTimelineIntegration: + """Integration tests for pricing timeline with real message types.""" + + @pytest.fixture + def sample_instance_content(self): + """Sample instance content for testing.""" + return InstanceContent.model_validate({ + "time": 1701099523.849, + "rootfs": { + "parent": { + "ref": "549ec451d9b099cad112d4aaa2c00ac40fb6729a92ff252ff22eef0b5c3cb613", + "use_latest": True, + }, + "size_mib": 20480, + "persistence": "host", + }, + "address": "0xTest", + "volumes": [], + "metadata": {"name": "Test Instance"}, + "resources": {"vcpus": 1, "memory": 2048, "seconds": 30}, + "allow_amend": False, + "environment": {"internet": True, "aleph_api": True}, + }) + + def test_pricing_timeline_with_message_processing(self, session_factory, pricing_aggregate_elements, sample_instance_content): + """Test that pricing timeline can be used for historical message cost calculation.""" + with session_factory() as session: + timeline = get_pricing_timeline(session) + + # Simulate processing a message at different points in time + message_time_1 = dt.datetime(2024, 1, 1, 13, 30, 0, 
tzinfo=dt.timezone.utc) # Between first and second update + message_time_2 = dt.datetime(2024, 1, 1, 15, 30, 0, tzinfo=dt.timezone.utc) # After all updates + + # Find applicable pricing for each message time + pricing_1 = None + pricing_2 = None + + for timestamp, pricing_model in timeline: + if timestamp <= message_time_1: + pricing_1 = pricing_model + if timestamp <= message_time_2: + pricing_2 = pricing_model + + # At time 1, should have storage pricing but not instance pricing + assert pricing_1 is not None + assert ProductPriceType.STORAGE in pricing_1 + # Instance pricing not added until third update + assert ProductPriceType.INSTANCE not in pricing_1 + + # At time 2, should have both storage and instance pricing + assert pricing_2 is not None + assert ProductPriceType.STORAGE in pricing_2 + assert ProductPriceType.INSTANCE in pricing_2 + + # Storage pricing should be different between the two time points + if ProductPriceType.STORAGE in pricing_1 and ProductPriceType.STORAGE in pricing_2: + storage_1 = pricing_1[ProductPriceType.STORAGE] + storage_2 = pricing_2[ProductPriceType.STORAGE] + # Should have different prices due to the third update + assert storage_1.price.storage.holding != storage_2.price.storage.holding \ No newline at end of file From 821a270476734f73d053fe08982e687ba619d6a1 Mon Sep 17 00:00:00 2001 From: amalcaraz Date: Tue, 24 Jun 2025 17:11:29 +0200 Subject: [PATCH 2/2] fix: lint --- ...1c06d0ade60c_calculate_costs_statically.py | 3 +- src/aleph/services/pricing_utils.py | 66 +++-- src/aleph/toolkit/constants.py | 12 +- src/aleph/web/controllers/prices.py | 83 ++++-- src/aleph/web/controllers/routes.py | 4 +- tests/api/test_pricing_recalculation.py | 274 +++++++++++------- tests/services/test_pricing_utils.py | 222 +++++++------- 7 files changed, 379 insertions(+), 285 deletions(-) diff --git a/deployment/migrations/versions/0033_1c06d0ade60c_calculate_costs_statically.py b/deployment/migrations/versions/0033_1c06d0ade60c_calculate_costs_statically.py index ea44cab1..6419c90d 100644 --- a/deployment/migrations/versions/0033_1c06d0ade60c_calculate_costs_statically.py +++ b/deployment/migrations/versions/0033_1c06d0ade60c_calculate_costs_statically.py @@ -7,7 +7,6 @@ """ from decimal import Decimal -from typing import Dict from alembic import op import sqlalchemy as sa import logging @@ -18,7 +17,7 @@ from aleph.db.accessors.messages import get_message_by_item_hash from aleph.services.cost import _is_confidential_vm, get_detailed_costs, CostComputableContent from aleph.services.pricing_utils import build_default_pricing_model -from aleph.types.cost import ProductPriceType, ProductPricing +from aleph.types.cost import ProductPriceType from aleph.types.db_session import DbSession logger = logging.getLogger("alembic") diff --git a/src/aleph/services/pricing_utils.py b/src/aleph/services/pricing_utils.py index 0e5f873b..11c1975c 100644 --- a/src/aleph/services/pricing_utils.py +++ b/src/aleph/services/pricing_utils.py @@ -5,7 +5,10 @@ import datetime as dt from typing import Dict, List, Union -from aleph.db.accessors.aggregates import get_aggregate_elements, merge_aggregate_elements +from aleph.db.accessors.aggregates import ( + get_aggregate_elements, + merge_aggregate_elements, +) from aleph.db.models import AggregateElementDb from aleph.toolkit.constants import ( DEFAULT_PRICE_AGGREGATE, @@ -16,22 +19,24 @@ from aleph.types.db_session import DbSession -def build_pricing_model_from_aggregate(aggregate_content: Dict[Union[ProductPriceType, str], dict]) -> 
Dict[ProductPriceType, ProductPricing]: +def build_pricing_model_from_aggregate( + aggregate_content: Dict[Union[ProductPriceType, str], dict] +) -> Dict[ProductPriceType, ProductPricing]: """ Build a complete pricing model from an aggregate content dictionary. - - This function converts the DEFAULT_PRICE_AGGREGATE format or any pricing aggregate - content into a dictionary of ProductPricing objects that can be used by the cost + + This function converts the DEFAULT_PRICE_AGGREGATE format or any pricing aggregate + content into a dictionary of ProductPricing objects that can be used by the cost calculation functions. - + Args: aggregate_content: Dictionary containing pricing information with ProductPriceType as keys - + Returns: Dictionary mapping ProductPriceType to ProductPricing objects """ pricing_model: Dict[ProductPriceType, ProductPricing] = {} - + for price_type, pricing_data in aggregate_content.items(): try: price_type = ProductPriceType(price_type) @@ -41,16 +46,17 @@ def build_pricing_model_from_aggregate(aggregate_content: Dict[Union[ProductPric except (KeyError, ValueError) as e: # Log the error but continue processing other price types import logging + logger = logging.getLogger(__name__) logger.warning(f"Failed to parse pricing for {price_type}: {e}") - + return pricing_model def build_default_pricing_model() -> Dict[ProductPriceType, ProductPricing]: """ Build the default pricing model from DEFAULT_PRICE_AGGREGATE constant. - + Returns: Dictionary mapping ProductPriceType to ProductPricing objects """ @@ -60,55 +66,57 @@ def build_default_pricing_model() -> Dict[ProductPriceType, ProductPricing]: def get_pricing_aggregate_history(session: DbSession) -> List[AggregateElementDb]: """ Get all pricing aggregate updates in chronological order. - + Args: session: Database session - + Returns: List of AggregateElementDb objects ordered by creation_datetime """ aggregate_elements = get_aggregate_elements( - session=session, - owner=PRICE_AGGREGATE_OWNER, - key=PRICE_AGGREGATE_KEY + session=session, owner=PRICE_AGGREGATE_OWNER, key=PRICE_AGGREGATE_KEY ) return list(aggregate_elements) -def get_pricing_timeline(session: DbSession) -> List[tuple[dt.datetime, Dict[ProductPriceType, ProductPricing]]]: +def get_pricing_timeline( + session: DbSession, +) -> List[tuple[dt.datetime, Dict[ProductPriceType, ProductPricing]]]: """ Get the complete pricing timeline with timestamps and pricing models. - + This function returns a chronologically ordered list of pricing changes, - useful for processing messages in chronological order and applying the + useful for processing messages in chronological order and applying the correct pricing at each point in time. - + This properly merges aggregate elements up to each point in time to create the cumulative pricing state, similar to how _update_aggregate works. 
- + Args: session: Database session - + Returns: List of tuples containing (timestamp, pricing_model) """ pricing_elements = get_pricing_aggregate_history(session) - + timeline = [] - + # Add default pricing as the initial state - timeline.append((dt.datetime.min.replace(tzinfo=dt.timezone.utc), build_default_pricing_model())) - + timeline.append( + (dt.datetime.min.replace(tzinfo=dt.timezone.utc), build_default_pricing_model()) + ) + # Build cumulative pricing models by merging elements up to each timestamp elements_so_far = [] for element in pricing_elements: elements_so_far.append(element) - + # Merge all elements up to this point to get the cumulative state merged_content = merge_aggregate_elements(elements_so_far) - + # Build pricing model from the merged content pricing_model = build_pricing_model_from_aggregate(merged_content) timeline.append((element.creation_datetime, pricing_model)) - - return timeline \ No newline at end of file + + return timeline diff --git a/src/aleph/toolkit/constants.py b/src/aleph/toolkit/constants.py index 1517ec2c..129df7d3 100644 --- a/src/aleph/toolkit/constants.py +++ b/src/aleph/toolkit/constants.py @@ -1,3 +1,7 @@ +from typing import Dict, Union + +from aleph.types.cost import ProductPriceType + KiB = 1024 MiB = 1024 * 1024 GiB = 1024 * 1024 * 1024 @@ -5,12 +9,10 @@ MINUTE = 60 HOUR = 60 * MINUTE -from aleph.types.cost import ProductPriceType - PRICE_AGGREGATE_OWNER = "0xFba561a84A537fCaa567bb7A2257e7142701ae2A" PRICE_AGGREGATE_KEY = "pricing" PRICE_PRECISION = 18 -DEFAULT_PRICE_AGGREGATE = { +DEFAULT_PRICE_AGGREGATE: Dict[Union[ProductPriceType, str], dict] = { ProductPriceType.PROGRAM: { "price": { "storage": {"payg": "0.000000977", "holding": "0.05"}, @@ -50,7 +52,9 @@ "memory_mib": 2048, }, }, - ProductPriceType.WEB3_HOSTING: {"price": {"fixed": 50, "storage": {"holding": "0.333333333"}}}, + ProductPriceType.WEB3_HOSTING: { + "price": {"fixed": 50, "storage": {"holding": "0.333333333"}} + }, ProductPriceType.PROGRAM_PERSISTENT: { "price": { "storage": {"payg": "0.000000977", "holding": "0.05"}, diff --git a/src/aleph/web/controllers/prices.py b/src/aleph/web/controllers/prices.py index fe7bcfe6..6f40796d 100644 --- a/src/aleph/web/controllers/prices.py +++ b/src/aleph/web/controllers/prices.py @@ -174,7 +174,7 @@ async def message_price_estimate(request: web.Request): async def recalculate_message_costs(request: web.Request): """Force recalculation of message costs in chronological order with historical pricing. - + This endpoint will: 1. Get all messages that need cost recalculation (if item_hash provided, just that message) 2. Get the pricing timeline to track price changes over time @@ -183,15 +183,15 @@ async def recalculate_message_costs(request: web.Request): 5. Delete existing cost entries and recalculate with historical pricing 6. 
Store the new cost calculations """ - + session_factory = get_session_factory_from_request(request) - + # Check if a specific message hash was provided item_hash_param = request.match_info.get("item_hash") - + with session_factory() as session: messages_to_recalculate: List[MessageDb] = [] - + if item_hash_param: # Recalculate costs for a specific message try: @@ -203,78 +203,97 @@ async def recalculate_message_costs(request: web.Request): # Recalculate costs for all executable messages, ordered by time (oldest first) select_stmt = ( select(MessageDb) - .where(MessageDb.type.in_([MessageType.instance, MessageType.program, MessageType.store])) + .where( + MessageDb.type.in_( + [MessageType.instance, MessageType.program, MessageType.store] + ) + ) .order_by(MessageDb.time.asc()) ) result = session.execute(select_stmt) messages_to_recalculate = result.scalars().all() - + if not messages_to_recalculate: return web.json_response( - {"message": "No messages found for cost recalculation", "recalculated_count": 0} + { + "message": "No messages found for cost recalculation", + "recalculated_count": 0, + } ) - + # Get the pricing timeline to track price changes over time pricing_timeline = get_pricing_timeline(session) LOGGER.info(f"Found {len(pricing_timeline)} pricing changes in timeline") - + recalculated_count = 0 errors = [] current_pricing_model = None current_pricing_index = 0 - + + settings = _get_settings(session) + for message in messages_to_recalculate: try: # Find the applicable pricing model for this message's timestamp - while (current_pricing_index < len(pricing_timeline) - 1 and - pricing_timeline[current_pricing_index + 1][0] <= message.time): + while ( + current_pricing_index < len(pricing_timeline) - 1 + and pricing_timeline[current_pricing_index + 1][0] <= message.time + ): current_pricing_index += 1 - + current_pricing_model = pricing_timeline[current_pricing_index][1] pricing_timestamp = pricing_timeline[current_pricing_index][0] - - LOGGER.debug(f"Message {message.item_hash} at {message.time} using pricing from {pricing_timestamp}") - + + LOGGER.debug( + f"Message {message.item_hash} at {message.time} using pricing from {pricing_timestamp}" + ) + # Delete existing cost entries for this message delete_costs_for_message(session, message.item_hash) - + # Get the message content and determine product type content: ExecutableContent = message.parsed_content - product_type = _get_product_price_type(content, None, current_pricing_model) - + product_type = _get_product_price_type( + content, settings, current_pricing_model + ) + # Get the pricing for this specific product type if product_type not in current_pricing_model: - LOGGER.warning(f"Product type {product_type} not found in pricing model for message {message.item_hash}") + LOGGER.warning( + f"Product type {product_type} not found in pricing model for message {message.item_hash}" + ) continue - + pricing = current_pricing_model[product_type] - + # Calculate new costs using the historical pricing model - new_costs = get_detailed_costs(session, content, message.item_hash, pricing) - + new_costs = get_detailed_costs( + session, content, message.item_hash, pricing + ) + if new_costs: # Store the new cost calculations upsert_stmt = make_costs_upsert_query(new_costs) session.execute(upsert_stmt) - + recalculated_count += 1 - + except Exception as e: error_msg = f"Failed to recalculate costs for message {message.item_hash}: {str(e)}" LOGGER.error(error_msg) errors.append({"item_hash": message.item_hash, "error": str(e)}) - + # Commit 
all changes
         session.commit()
-    
+
         response_data = {
             "message": "Cost recalculation completed with historical pricing",
             "recalculated_count": recalculated_count,
             "total_messages": len(messages_to_recalculate),
-            "pricing_changes_found": len(pricing_timeline)
+            "pricing_changes_found": len(pricing_timeline),
         }
-    
+
         if errors:
             response_data["errors"] = errors
-    
+
         return web.json_response(response_data)
diff --git a/src/aleph/web/controllers/routes.py b/src/aleph/web/controllers/routes.py
index 297bf940..ce931480 100644
--- a/src/aleph/web/controllers/routes.py
+++ b/src/aleph/web/controllers/routes.py
@@ -68,7 +68,9 @@ def register_routes(app: web.Application):
     app.router.add_get("/api/v0/price/{item_hash}", prices.message_price)
     app.router.add_post("/api/v0/price/estimate", prices.message_price_estimate)
     app.router.add_post("/api/v0/price/recalculate", prices.recalculate_message_costs)
-    app.router.add_post("/api/v0/price/{item_hash}/recalculate", prices.recalculate_message_costs)
+    app.router.add_post(
+        "/api/v0/price/{item_hash}/recalculate", prices.recalculate_message_costs
+    )
 
     app.router.add_get("/api/v0/addresses/stats.json", accounts.addresses_stats_view)
     app.router.add_get(
diff --git a/tests/api/test_pricing_recalculation.py b/tests/api/test_pricing_recalculation.py
index 44e4d1c0..f42d9015 100644
--- a/tests/api/test_pricing_recalculation.py
+++ b/tests/api/test_pricing_recalculation.py
@@ -10,12 +10,8 @@
 from aleph.db.models import MessageDb
 from aleph.db.models.account_costs import AccountCostsDb
 from aleph.db.models.aggregates import AggregateElementDb
-from aleph.toolkit.constants import (
-    PRICE_AGGREGATE_KEY,
-    PRICE_AGGREGATE_OWNER,
-)
+from aleph.toolkit.constants import PRICE_AGGREGATE_KEY, PRICE_AGGREGATE_OWNER
 from aleph.types.cost import ProductPriceType
-from aleph.types.message_status import MessageStatus
 from aleph.web.controllers.prices import recalculate_message_costs
 
 
@@ -23,7 +19,7 @@
 def sample_messages(session_factory):
     """Create sample messages for testing cost recalculation."""
     base_time = dt.datetime(2024, 1, 1, 10, 0, 0, tzinfo=dt.timezone.utc)
-    
+
     # Create sample instance message
     instance_message = MessageDb(
         item_hash="instance_msg_1",
@@ -48,7 +44,7 @@ def sample_messages(session_factory):
         time=base_time + dt.timedelta(hours=1),
         size=1024,
     )
-    
+
     # Create sample program message
     program_message = MessageDb(
         item_hash="program_msg_1",
@@ -59,7 +55,12 @@ def sample_messages(session_factory):
         content={
             "time": (base_time + dt.timedelta(hours=2)).timestamp(),
             "on": {"http": True, "persistent": False},
-            "code": {"ref": "code_ref", "encoding": "zip", "entrypoint": "main:app", "use_latest": True},
+            "code": {
+                "ref": "code_ref",
+                "encoding": "zip",
+                "entrypoint": "main:app",
+                "use_latest": True,
+            },
             "runtime": {"ref": "runtime_ref", "use_latest": True},
             "address": "0xTest2",
             "resources": {"vcpus": 1, "memory": 128, "seconds": 30},
@@ -69,7 +70,7 @@ def sample_messages(session_factory):
         time=base_time + dt.timedelta(hours=2),
         size=512,
     )
-    
+
     # Create sample store message
     store_message = MessageDb(
         item_hash="store_msg_1",
@@ -86,16 +87,16 @@ def sample_messages(session_factory):
         time=base_time + dt.timedelta(hours=3),
         size=2048,
     )
-    
+
     with session_factory() as session:
         session.add(instance_message)
-        session.add(program_message) 
+        session.add(program_message)
         session.add(store_message)
         session.commit()
         session.refresh(instance_message)
         session.refresh(program_message)
         session.refresh(store_message)
-    
+
     return [instance_message, program_message, store_message]
 
 
@@ -103,27 +104,25 @@ def sample_messages(session_factory):
 def pricing_updates_with_timeline(session_factory):
     """Create pricing updates that form a timeline for testing."""
     base_time = dt.datetime(2024, 1, 1, 9, 0, 0, tzinfo=dt.timezone.utc)
-    
+
     # First pricing update - before any messages
     element1 = AggregateElementDb(
         item_hash="pricing_1",
         key=PRICE_AGGREGATE_KEY,
         owner=PRICE_AGGREGATE_OWNER,
         content={
-            ProductPriceType.STORAGE: {
-                "price": {"storage": {"holding": "0.1"}}
-            },
+            ProductPriceType.STORAGE: {"price": {"storage": {"holding": "0.1"}}},
             ProductPriceType.INSTANCE: {
                 "price": {
                     "storage": {"holding": "0.05"},
                     "compute_unit": {"holding": "500"},
                 },
                 "compute_unit": {"vcpus": 1, "disk_mib": 20480, "memory_mib": 2048},
-            }
+            },
         },
-        creation_datetime=base_time + dt.timedelta(minutes=30)
+        creation_datetime=base_time + dt.timedelta(minutes=30),
     )
-    
+
     # Second pricing update - between instance and program messages
     element2 = AggregateElementDb(
         item_hash="pricing_2",
@@ -138,9 +137,9 @@ def pricing_updates_with_timeline(session_factory):
                 "compute_unit": {"vcpus": 1, "disk_mib": 2048, "memory_mib": 2048},
             }
         },
-        creation_datetime=base_time + dt.timedelta(hours=1, minutes=30)
+        creation_datetime=base_time + dt.timedelta(hours=1, minutes=30),
     )
-    
+
     # Third pricing update - after program but before store message
     element3 = AggregateElementDb(
         item_hash="pricing_3",
@@ -151,9 +150,9 @@ def pricing_updates_with_timeline(session_factory):
                 "price": {"storage": {"holding": "0.2"}}  # Updated storage price
             }
         },
-        creation_datetime=base_time + dt.timedelta(hours=2, minutes=30)
+        creation_datetime=base_time + dt.timedelta(hours=2, minutes=30),
     )
-    
+
     with session_factory() as session:
         session.add(element1)
         session.add(element2)
@@ -162,7 +161,7 @@ def pricing_updates_with_timeline(session_factory):
         session.refresh(element1)
         session.refresh(element2)
         session.refresh(element3)
-    
+
     return [element1, element2, element3]
 
 
@@ -170,7 +169,7 @@ def pricing_updates_with_timeline(session_factory):
 def existing_costs(session_factory, sample_messages):
     """Create some existing cost entries to test deletion and recalculation."""
     costs = []
-    
+
     for message in sample_messages:
         cost = AccountCostsDb(
             owner=message.sender,
@@ -182,143 +181,176 @@ def existing_costs(session_factory, sample_messages):
             cost_stream=Decimal("0.01"),
         )
         costs.append(cost)
-    
+
     with session_factory() as session:
         for cost in costs:
            session.add(cost)
         session.commit()
-    
+
     return costs
 
 
 class TestRecalculateMessageCosts:
     """Tests for the message cost recalculation endpoint."""
-    
+
     @pytest.fixture
     def mock_request_factory(self, session_factory):
         """Factory to create mock requests."""
+
         def _create_mock_request(match_info=None):
             request = web.Request.__new__(web.Request)
             request._match_info = match_info or {}
-            
+
             # Mock the session factory getter
             def get_session_factory():
                 return session_factory
-            
+
             request._session_factory = get_session_factory
             return request
-        
+
         return _create_mock_request
-    
-    @patch('aleph.web.controllers.prices.get_session_factory_from_request')
-    async def test_recalculate_all_messages_empty_db(self, mock_get_session, session_factory, mock_request_factory):
+
+    @patch("aleph.web.controllers.prices.get_session_factory_from_request")
+    async def test_recalculate_all_messages_empty_db(
+        self, mock_get_session, session_factory, mock_request_factory
+    ):
         """Test recalculation when no messages exist."""
         mock_get_session.return_value = session_factory
         request = mock_request_factory()
-        
+
         response = await recalculate_message_costs(request)
-        
+
         assert response.status == 200
         response_data = json.loads(response.text)
         assert response_data["recalculated_count"] == 0
         assert response_data["total_messages"] == 0
         assert "No messages found" in response_data["message"]
-    
-    @patch('aleph.web.controllers.prices.get_session_factory_from_request')
-    @patch('aleph.web.controllers.prices.get_executable_message')
-    async def test_recalculate_specific_message(self, mock_get_executable, mock_get_session, session_factory, sample_messages, mock_request_factory):
+
+    @patch("aleph.web.controllers.prices.get_session_factory_from_request")
+    @patch("aleph.web.controllers.prices.get_executable_message")
+    async def test_recalculate_specific_message(
+        self,
+        mock_get_executable,
+        mock_get_session,
+        session_factory,
+        sample_messages,
+        mock_request_factory,
+    ):
         """Test recalculation of a specific message."""
         mock_get_session.return_value = session_factory
         mock_get_executable.return_value = sample_messages[0]  # Return first message
-        
+
         request = mock_request_factory({"item_hash": "instance_msg_1"})
-        
-        with patch('aleph.web.controllers.prices.get_detailed_costs') as mock_get_costs:
+
+        with patch("aleph.web.controllers.prices.get_detailed_costs") as mock_get_costs:
             mock_get_costs.return_value = []  # Mock empty costs
-            
+
             response = await recalculate_message_costs(request)
-            
+
             assert response.status == 200
             response_data = json.loads(response.text)
             assert response_data["recalculated_count"] == 1
             assert response_data["total_messages"] == 1
             assert "historical pricing" in response_data["message"]
-            
+
             # Should have called get_detailed_costs once
             assert mock_get_costs.call_count == 1
-    
-    @patch('aleph.web.controllers.prices.get_session_factory_from_request')
-    async def test_recalculate_all_messages_with_timeline(self, mock_get_session, session_factory, sample_messages, pricing_updates_with_timeline, existing_costs, mock_request_factory):
+
+    @patch("aleph.web.controllers.prices.get_session_factory_from_request")
+    async def test_recalculate_all_messages_with_timeline(
+        self,
+        mock_get_session,
+        session_factory,
+        sample_messages,
+        pricing_updates_with_timeline,
+        existing_costs,
+        mock_request_factory,
+    ):
         """Test recalculation of all messages with pricing timeline."""
         mock_get_session.return_value = session_factory
-        
+
         request = mock_request_factory()
-        
-        with patch('aleph.web.controllers.prices.get_detailed_costs') as mock_get_costs:
+
+        with patch("aleph.web.controllers.prices.get_detailed_costs") as mock_get_costs:
             mock_get_costs.return_value = []  # Mock empty costs
-            
+
             response = await recalculate_message_costs(request)
-            
+
             assert response.status == 200
             response_data = json.loads(response.text)
             assert response_data["recalculated_count"] == 3
             assert response_data["total_messages"] == 3
             assert response_data["pricing_changes_found"] == 4  # Default + 3 updates
-            
+
             # Should have called get_detailed_costs for each message
            assert mock_get_costs.call_count == 3
-        
+
         # Verify old costs were deleted
         with session_factory() as session:
             remaining_costs = session.query(AccountCostsDb).all()
             assert len(remaining_costs) == 0  # All old costs should be deleted
-    
-    @patch('aleph.web.controllers.prices.get_session_factory_from_request')
-    async def test_recalculate_with_pricing_timeline_application(self, mock_get_session, session_factory, sample_messages, pricing_updates_with_timeline, mock_request_factory):
+
+    @patch("aleph.web.controllers.prices.get_session_factory_from_request")
+    async def test_recalculate_with_pricing_timeline_application(
+        self,
+        mock_get_session,
+        session_factory,
+        sample_messages,
+        pricing_updates_with_timeline,
+        mock_request_factory,
+    ):
         """Test that the correct pricing model is applied based on message timestamps."""
         mock_get_session.return_value = session_factory
-        
+
         request = mock_request_factory()
-        
+
         pricing_calls = []
-        
+
         def mock_get_costs(session, content, item_hash, pricing):
             # Capture the pricing object used for each call
             pricing_calls.append((item_hash, pricing.type if pricing else None))
             return []
-        
-        with patch('aleph.web.controllers.prices.get_detailed_costs', side_effect=mock_get_costs):
+
+        with patch(
+            "aleph.web.controllers.prices.get_detailed_costs",
+            side_effect=mock_get_costs,
+        ):
             response = await recalculate_message_costs(request)
-            
+
         assert response.status == 200
-        
+
         # Should have made calls for all 3 messages
         assert len(pricing_calls) == 3
-        
+
        # Verify the correct pricing types were used (based on message content and timeline)
         item_hashes = [call[0] for call in pricing_calls]
         assert "instance_msg_1" in item_hashes
         assert "program_msg_1" in item_hashes
         assert "store_msg_1" in item_hashes
-    
-    @patch('aleph.web.controllers.prices.get_session_factory_from_request')
-    async def test_recalculate_with_errors(self, mock_get_session, session_factory, sample_messages, mock_request_factory):
+
+    @patch("aleph.web.controllers.prices.get_session_factory_from_request")
+    async def test_recalculate_with_errors(
+        self, mock_get_session, session_factory, sample_messages, mock_request_factory
+    ):
         """Test recalculation handling of errors."""
         mock_get_session.return_value = session_factory
-        
+
         request = mock_request_factory()
-        
+
         def mock_get_costs_with_error(session, content, item_hash, pricing):
             if item_hash == "program_msg_1":
                 raise ValueError("Test error for program message")
             return []
-        
-        with patch('aleph.web.controllers.prices.get_detailed_costs', side_effect=mock_get_costs_with_error):
+
+        with patch(
+            "aleph.web.controllers.prices.get_detailed_costs",
+            side_effect=mock_get_costs_with_error,
+        ):
             response = await recalculate_message_costs(request)
-            
+
         assert response.status == 200
         response_data = json.loads(response.text)
-        
+
         # Should have processed 2 successfully, 1 with error
         assert response_data["recalculated_count"] == 2
         assert response_data["total_messages"] == 3
@@ -326,37 +358,48 @@ def mock_get_costs_with_error(session, content, item_hash, pricing):
         assert len(response_data["errors"]) == 1
         assert response_data["errors"][0]["item_hash"] == "program_msg_1"
         assert "Test error" in response_data["errors"][0]["error"]
-    
-    @patch('aleph.web.controllers.prices.get_session_factory_from_request')
-    @patch('aleph.web.controllers.prices.get_executable_message')
-    async def test_recalculate_specific_message_not_found(self, mock_get_executable, mock_get_session, session_factory, mock_request_factory):
+
+    @patch("aleph.web.controllers.prices.get_session_factory_from_request")
+    @patch("aleph.web.controllers.prices.get_executable_message")
+    async def test_recalculate_specific_message_not_found(
+        self,
+        mock_get_executable,
+        mock_get_session,
+        session_factory,
+        mock_request_factory,
+    ):
         """Test recalculation of a specific message that doesn't exist."""
         mock_get_session.return_value = session_factory
         mock_get_executable.side_effect = web.HTTPNotFound(body="Message not found")
-        
+
         request = mock_request_factory({"item_hash": "nonexistent_hash"})
-        
+
         with pytest.raises(web.HTTPNotFound):
             await recalculate_message_costs(request)
-    
-    @patch('aleph.web.controllers.prices.get_session_factory_from_request')
-    async def test_chronological_processing_order(self, mock_get_session, session_factory, sample_messages, mock_request_factory):
+
+    @patch("aleph.web.controllers.prices.get_session_factory_from_request")
+    async def test_chronological_processing_order(
+        self, mock_get_session, session_factory, sample_messages, mock_request_factory
+    ):
         """Test that messages are processed in chronological order."""
         mock_get_session.return_value = session_factory
-        
+
         request = mock_request_factory()
-        
+
         processed_order = []
-        
+
         def mock_get_costs(session, content, item_hash, pricing):
             processed_order.append(item_hash)
             return []
-        
-        with patch('aleph.web.controllers.prices.get_detailed_costs', side_effect=mock_get_costs):
+
+        with patch(
+            "aleph.web.controllers.prices.get_detailed_costs",
+            side_effect=mock_get_costs,
+        ):
             response = await recalculate_message_costs(request)
-            
+
         assert response.status == 200
-        
+
         # Should have processed in chronological order based on message.time
         expected_order = ["instance_msg_1", "program_msg_1", "store_msg_1"]
         assert processed_order == expected_order
@@ -364,47 +407,60 @@ def mock_get_costs(session, content, item_hash, pricing):
 
 class TestPricingTimelineIntegration:
     """Integration tests for the complete pricing timeline feature."""
-    
+
     @pytest.fixture
     def mock_request_factory(self, session_factory):
         """Factory to create mock requests."""
+
        def _create_mock_request(match_info=None):
             request = web.Request.__new__(web.Request)
             request._match_info = match_info or {}
-            
+
             # Mock the session factory getter
             def get_session_factory():
                 return session_factory
-            
+
             request._session_factory = get_session_factory
             return request
-        
+
         return _create_mock_request
-    
-    @patch('aleph.web.controllers.prices.get_session_factory_from_request')
-    async def test_end_to_end_historical_pricing(self, mock_get_session, session_factory, sample_messages, pricing_updates_with_timeline, mock_request_factory):
+
+    @patch("aleph.web.controllers.prices.get_session_factory_from_request")
+    async def test_end_to_end_historical_pricing(
+        self,
+        mock_get_session,
+        session_factory,
+        sample_messages,
+        pricing_updates_with_timeline,
+        mock_request_factory,
+    ):
         """End-to-end test of historical pricing application."""
         mock_get_session.return_value = session_factory
-        
+
         request = mock_request_factory()
-        
+
         # Track which pricing models are used for each message
         pricing_usage = {}
-        
+
         def mock_get_costs(session, content, item_hash, pricing):
-            if pricing and hasattr(pricing, 'price'):
-                if hasattr(pricing.price, 'storage') and hasattr(pricing.price.storage, 'holding'):
+            if pricing and hasattr(pricing, "price"):
+                if hasattr(pricing.price, "storage") and hasattr(
+                    pricing.price.storage, "holding"
+                ):
                     pricing_usage[item_hash] = float(pricing.price.storage.holding)
             return []
-        
-        with patch('aleph.web.controllers.prices.get_detailed_costs', side_effect=mock_get_costs):
+
+        with patch(
+            "aleph.web.controllers.prices.get_detailed_costs",
+            side_effect=mock_get_costs,
+        ):
             response = await recalculate_message_costs(request)
-            
+
         assert response.status == 200
         response_data = json.loads(response.text)
         assert response_data["recalculated_count"] == 3
-        
+
         # Verify that different pricing was applied based on timeline
         # The exact values depend on the pricing timeline and merge logic,
         # but we can verify that historical pricing was considered
-        assert len(pricing_usage) > 0
\ No newline at end of file
+        assert len(pricing_usage) > 0
diff --git a/tests/services/test_pricing_utils.py b/tests/services/test_pricing_utils.py
index 1cade1ff..5f182e69 100644
--- a/tests/services/test_pricing_utils.py
+++ b/tests/services/test_pricing_utils.py
@@ -1,32 +1,25 @@
 import datetime as dt
 from decimal import Decimal
-from unittest.mock import patch
 
 import pytest
 
-from aleph_message.models import InstanceContent, ProgramContent, StoreContent
+from aleph_message.models import InstanceContent
 from aleph.db.models.aggregates import AggregateElementDb
 from aleph.services.pricing_utils import (
     build_default_pricing_model,
     build_pricing_model_from_aggregate,
-    get_pricing_timeline,
     get_pricing_aggregate_history,
+    get_pricing_timeline,
 )
-from aleph.toolkit.constants import (
-    PRICE_AGGREGATE_KEY,
-    PRICE_AGGREGATE_OWNER,
-)
+from aleph.toolkit.constants import PRICE_AGGREGATE_KEY, PRICE_AGGREGATE_OWNER
 from aleph.types.cost import ProductPriceType, ProductPricing
-from aleph.types.db_session import DbSessionFactory
 
 
 @pytest.fixture
 def sample_pricing_aggregate_content():
     """Sample pricing aggregate content with ProductPriceType keys."""
     return {
-        ProductPriceType.STORAGE: {
-            "price": {"storage": {"holding": "0.5"}}
-        },
+        ProductPriceType.STORAGE: {"price": {"storage": {"holding": "0.5"}}},
         ProductPriceType.PROGRAM: {
             "price": {
                 "storage": {"payg": "0.000001", "holding": "0.1"},
@@ -56,23 +49,19 @@ def sample_pricing_aggregate_content():
 def pricing_aggregate_elements(session_factory):
     """Create sample pricing aggregate elements for timeline testing."""
     base_time = dt.datetime(2024, 1, 1, 12, 0, 0, tzinfo=dt.timezone.utc)
-    
+
     # First pricing update - only storage pricing
     element1 = AggregateElementDb(
         item_hash="pricing_update_1",
         key=PRICE_AGGREGATE_KEY,
         owner=PRICE_AGGREGATE_OWNER,
-        content={
-            ProductPriceType.STORAGE: {
-                "price": {"storage": {"holding": "0.2"}}
-            }
-        },
-        creation_datetime=base_time + dt.timedelta(hours=1)
+        content={ProductPriceType.STORAGE: {"price": {"storage": {"holding": "0.2"}}}},
+        creation_datetime=base_time + dt.timedelta(hours=1),
     )
-    
+
     # Second pricing update - adds program pricing
     element2 = AggregateElementDb(
-        item_hash="pricing_update_2", 
+        item_hash="pricing_update_2",
         key=PRICE_AGGREGATE_KEY,
         owner=PRICE_AGGREGATE_OWNER,
         content={
@@ -88,18 +77,16 @@ def pricing_aggregate_elements(session_factory):
                 },
             }
         },
-        creation_datetime=base_time + dt.timedelta(hours=2)
+        creation_datetime=base_time + dt.timedelta(hours=2),
     )
-    
+
     # Third pricing update - updates storage and adds instance pricing
     element3 = AggregateElementDb(
         item_hash="pricing_update_3",
         key=PRICE_AGGREGATE_KEY,
         owner=PRICE_AGGREGATE_OWNER,
         content={
-            ProductPriceType.STORAGE: {
-                "price": {"storage": {"holding": "0.3"}}
-            },
+            ProductPriceType.STORAGE: {"price": {"storage": {"holding": "0.3"}}},
             ProductPriceType.INSTANCE: {
                 "price": {
                     "storage": {"payg": "0.000001", "holding": "0.05"},
@@ -110,11 +97,11 @@ def pricing_aggregate_elements(session_factory):
                     "disk_mib": 20480,
                     "memory_mib": 2048,
                 },
-            }
+            },
         },
-        creation_datetime=base_time + dt.timedelta(hours=3)
+        creation_datetime=base_time + dt.timedelta(hours=3),
     )
-    
+
     with session_factory() as session:
         session.add(element1)
         session.add(element2)
@@ -123,42 +110,42 @@ def pricing_aggregate_elements(session_factory):
         session.refresh(element1)
         session.refresh(element2)
         session.refresh(element3)
-    
+
     return [element1, element2, element3]
 
 
 class TestBuildPricingModelFromAggregate:
     """Tests for building pricing models from aggregate content."""
-    
+
     def test_build_pricing_model_from_aggregate(self, sample_pricing_aggregate_content):
         """Test building pricing model from aggregate content."""
-        pricing_model = build_pricing_model_from_aggregate(sample_pricing_aggregate_content)
-        
+        pricing_model = build_pricing_model_from_aggregate(
+            sample_pricing_aggregate_content
+        )
+
         # Check that we got ProductPricing objects for each type
         assert len(pricing_model) == 3
         assert ProductPriceType.STORAGE in pricing_model
         assert ProductPriceType.PROGRAM in pricing_model
         assert ProductPriceType.INSTANCE in pricing_model
-        
+
         # Check that each is a ProductPricing object
         for price_type, pricing in pricing_model.items():
             assert isinstance(pricing, ProductPricing)
             assert pricing.type == price_type
-    
+
     def test_build_pricing_model_with_missing_types(self):
         """Test building pricing model with some missing product types."""
         partial_content = {
-            ProductPriceType.STORAGE: {
-                "price": {"storage": {"holding": "0.5"}}
-            }
+            ProductPriceType.STORAGE: {"price": {"storage": {"holding": "0.5"}}}
         }
-        
+
         pricing_model = build_pricing_model_from_aggregate(partial_content)
-        
+
         assert len(pricing_model) == 1
         assert ProductPriceType.STORAGE in pricing_model
         assert ProductPriceType.PROGRAM not in pricing_model
-    
+
     def test_build_pricing_model_with_invalid_data(self):
         """Test building pricing model with invalid pricing data."""
         invalid_content = {
@@ -166,7 +153,7 @@ def test_build_pricing_model_with_invalid_data(self):
                 "invalid": "data"  # Missing required "price" key
             }
         }
-        
+
         # Should handle the error gracefully and return empty model
         pricing_model = build_pricing_model_from_aggregate(invalid_content)
         assert len(pricing_model) == 0
@@ -174,11 +161,11 @@ def test_build_pricing_model_with_invalid_data(self):
 
 class TestBuildDefaultPricingModel:
     """Tests for building the default pricing model."""
-    
+
     def test_build_default_pricing_model(self):
         """Test building the default pricing model from constants."""
         pricing_model = build_default_pricing_model()
-        
+
         # Should contain all the expected product types from DEFAULT_PRICE_AGGREGATE
         expected_types = [
             ProductPriceType.PROGRAM,
@@ -190,7 +177,7 @@ def test_build_default_pricing_model(self):
             ProductPriceType.INSTANCE_GPU_STANDARD,
             ProductPriceType.WEB3_HOSTING,
         ]
-        
+
         for price_type in expected_types:
             assert price_type in pricing_model
             assert isinstance(pricing_model[price_type], ProductPricing)
@@ -198,24 +185,26 @@ def test_build_default_pricing_model(self):
 
 class TestGetPricingAggregateHistory:
     """Tests for retrieving pricing aggregate history."""
-    
+
     def test_get_pricing_aggregate_history_empty(self, session_factory):
         """Test getting pricing history when no elements exist."""
         with session_factory() as session:
             history = get_pricing_aggregate_history(session)
             assert len(history) == 0
-    
-    def test_get_pricing_aggregate_history_with_elements(self, session_factory, pricing_aggregate_elements):
+
+    def test_get_pricing_aggregate_history_with_elements(
+        self, session_factory, pricing_aggregate_elements
+    ):
         """Test getting pricing history with existing elements."""
         with session_factory() as session:
             history = get_pricing_aggregate_history(session)
-            
+
             assert len(history) == 3
-            
+
             # Should be ordered chronologically
             assert history[0].creation_datetime < history[1].creation_datetime
             assert history[1].creation_datetime < history[2].creation_datetime
-            
+
             # Check content
             assert ProductPriceType.STORAGE in history[0].content
             assert ProductPriceType.PROGRAM in history[1].content
@@ -225,81 +214,85 @@ def test_get_pricing_aggregate_history_with_elements(self, session_factory, pric
 
 class TestGetPricingTimeline:
     """Tests for getting the pricing timeline."""
-    
+
     def test_get_pricing_timeline_empty(self, session_factory):
         """Test getting pricing timeline when no aggregate elements exist."""
         with session_factory() as session:
             timeline = get_pricing_timeline(session)
-            
+
             # Should have at least the default pricing
             assert len(timeline) == 1
             timestamp, pricing_model = timeline[0]
-            
+
            # Should use minimum datetime for default pricing
             assert timestamp == dt.datetime.min.replace(tzinfo=dt.timezone.utc)
             assert isinstance(pricing_model, dict)
             assert ProductPriceType.STORAGE in pricing_model
-    
-    def test_get_pricing_timeline_with_elements(self, session_factory, pricing_aggregate_elements):
+
+    def test_get_pricing_timeline_with_elements(
+        self, session_factory, pricing_aggregate_elements
+    ):
         """Test getting pricing timeline with aggregate elements."""
         with session_factory() as session:
             timeline = get_pricing_timeline(session)
-            
+
            # Should have default + 3 pricing updates
            assert len(timeline) == 4
-            
+
             # Check chronological order
             for i in range(len(timeline) - 1):
                 assert timeline[i][0] <= timeline[i + 1][0]
-            
+
             # Check content evolution
             default_timestamp, default_model = timeline[0]
             first_timestamp, first_model = timeline[1]
             second_timestamp, second_model = timeline[2]
             third_timestamp, third_model = timeline[3]
-            
+
             # First update: only storage pricing
             assert ProductPriceType.STORAGE in first_model
             storage_pricing_1 = first_model[ProductPriceType.STORAGE]
             assert storage_pricing_1.price.storage.holding == Decimal("0.2")
-            
+
             # Second update: storage + program pricing (cumulative)
             assert ProductPriceType.STORAGE in second_model
             assert ProductPriceType.PROGRAM in second_model
             storage_pricing_2 = second_model[ProductPriceType.STORAGE]
-            assert storage_pricing_2.price.storage.holding == Decimal("0.2")  # Still from first update
-            
+            assert storage_pricing_2.price.storage.holding == Decimal(
+                "0.2"
+            )  # Still from first update
+
             # Third update: updated storage + program + instance pricing (cumulative)
             assert ProductPriceType.STORAGE in third_model
             assert ProductPriceType.PROGRAM in third_model
             assert ProductPriceType.INSTANCE in third_model
             storage_pricing_3 = third_model[ProductPriceType.STORAGE]
-            assert storage_pricing_3.price.storage.holding == Decimal("0.3")  # Updated value
-    
+            assert storage_pricing_3.price.storage.holding == Decimal(
+                "0.3"
+            )  # Updated value
+
     def test_pricing_timeline_cumulative_merging(self, session_factory):
         """Test that pricing timeline properly merges cumulative changes."""
         base_time = dt.datetime(2024, 1, 1, 12, 0, 0, tzinfo=dt.timezone.utc)
-        
+
         # Create elements that update different parts of pricing
         element1 = AggregateElementDb(
             item_hash="test_1",
             key=PRICE_AGGREGATE_KEY,
            owner=PRICE_AGGREGATE_OWNER,
            content={
-                ProductPriceType.STORAGE: {
-                    "price": {"storage": {"holding": "1.0"}}
-                },
+                ProductPriceType.STORAGE: {"price": {"storage": {"holding": "1.0"}}},
                 ProductPriceType.PROGRAM: {
                     "price": {
                         "storage": {"holding": "0.1"},
                        "compute_unit": {"holding": "100"},
                     },
                     "compute_unit": {"vcpus": 1, "disk_mib": 1024, "memory_mib": 1024},
-                }
+                },
             },
-            creation_datetime=base_time + dt.timedelta(hours=1)
+            creation_datetime=base_time + dt.timedelta(hours=1),
         )
-        
+
         # Second element only updates storage, should preserve program settings
         element2 = AggregateElementDb(
             item_hash="test_2",
@@ -310,29 +303,29 @@ def test_pricing_timeline_cumulative_merging(self, session_factory):
                     "price": {"storage": {"holding": "2.0"}}  # Updated price
                 }
             },
-            creation_datetime=base_time + dt.timedelta(hours=2)
+            creation_datetime=base_time + dt.timedelta(hours=2),
        )
-        
+
        with session_factory() as session:
            session.add(element1)
            session.add(element2)
            session.commit()
-            
+
            timeline = get_pricing_timeline(session)
-            
+
            # Should have default + 2 updates
            assert len(timeline) == 3
-            
+
            # Check final state has both storage and program, with updated storage price
            final_timestamp, final_model = timeline[2]
-            
+
            assert ProductPriceType.STORAGE in final_model
            assert ProductPriceType.PROGRAM in final_model
-            
+
            # Storage should have updated price
            storage_pricing = final_model[ProductPriceType.STORAGE]
            assert storage_pricing.price.storage.holding == Decimal("2.0")
-            
+
            # Program should still have original settings
            program_pricing = final_model[ProductPriceType.PROGRAM]
            assert program_pricing.price.storage.holding == Decimal("0.1")
@@ -341,61 +334,74 @@ class TestPricingTimelineIntegration:
     """Integration tests for pricing timeline with real message types."""
-    
+
     @pytest.fixture
     def sample_instance_content(self):
         """Sample instance content for testing."""
-        return InstanceContent.model_validate({
-            "time": 1701099523.849,
-            "rootfs": {
-                "parent": {
-                    "ref": "549ec451d9b099cad112d4aaa2c00ac40fb6729a92ff252ff22eef0b5c3cb613",
-                    "use_latest": True,
+        return InstanceContent.model_validate(
+            {
+                "time": 1701099523.849,
+                "rootfs": {
+                    "parent": {
+                        "ref": "549ec451d9b099cad112d4aaa2c00ac40fb6729a92ff252ff22eef0b5c3cb613",
+                        "use_latest": True,
+                    },
+                    "size_mib": 20480,
+                    "persistence": "host",
                 },
-                "size_mib": 20480,
-                "persistence": "host",
-            },
-            "address": "0xTest",
-            "volumes": [],
-            "metadata": {"name": "Test Instance"},
-            "resources": {"vcpus": 1, "memory": 2048, "seconds": 30},
-            "allow_amend": False,
-            "environment": {"internet": True, "aleph_api": True},
-        })
-    
-    def test_pricing_timeline_with_message_processing(self, session_factory, pricing_aggregate_elements, sample_instance_content):
+                "address": "0xTest",
+                "volumes": [],
+                "metadata": {"name": "Test Instance"},
+                "resources": {"vcpus": 1, "memory": 2048, "seconds": 30},
+                "allow_amend": False,
+                "environment": {"internet": True, "aleph_api": True},
+            }
+        )
+
+    def test_pricing_timeline_with_message_processing(
+        self, session_factory, pricing_aggregate_elements, sample_instance_content
+    ):
         """Test that pricing timeline can be used for historical message cost calculation."""
         with session_factory() as session:
             timeline = get_pricing_timeline(session)
-            
+
             # Simulate processing a message at different points in time
-            message_time_1 = dt.datetime(2024, 1, 1, 13, 30, 0, tzinfo=dt.timezone.utc)  # Between first and second update
-            message_time_2 = dt.datetime(2024, 1, 1, 15, 30, 0, tzinfo=dt.timezone.utc)  # After all updates
-            
+            message_time_1 = dt.datetime(
+                2024, 1, 1, 13, 30, 0, tzinfo=dt.timezone.utc
+            )  # Between first and second update
+            message_time_2 = dt.datetime(
+                2024, 1, 1, 15, 30, 0, tzinfo=dt.timezone.utc
+            )  # After all updates
+
+            # Find applicable pricing for each message time
             pricing_1 = None
             pricing_2 = None
-            
+
             for timestamp, pricing_model in timeline:
                 if timestamp <= message_time_1:
                     pricing_1 = pricing_model
                 if timestamp <= message_time_2:
                     pricing_2 = pricing_model
-            
+
             # At time 1, should have storage pricing but not instance pricing
             assert pricing_1 is not None
             assert ProductPriceType.STORAGE in pricing_1
             # Instance pricing not added until third update
             assert ProductPriceType.INSTANCE not in pricing_1
-            
+
             # At time 2, should have both storage and instance pricing
             assert pricing_2 is not None
             assert ProductPriceType.STORAGE in pricing_2
             assert ProductPriceType.INSTANCE in pricing_2
-            
+
             # Storage pricing should be different between the two time points
-            if ProductPriceType.STORAGE in pricing_1 and ProductPriceType.STORAGE in pricing_2:
+            if (
+                ProductPriceType.STORAGE in pricing_1
+                and ProductPriceType.STORAGE in pricing_2
+            ):
                 storage_1 = pricing_1[ProductPriceType.STORAGE]
                 storage_2 = pricing_2[ProductPriceType.STORAGE]
                 # Should have different prices due to the third update
-                assert storage_1.price.storage.holding != storage_2.price.storage.holding
\ No newline at end of file
+                assert (
+                    storage_1.price.storage.holding != storage_2.price.storage.holding
+                )
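Reviewer note, not part of the patch: a minimal client sketch showing how the two endpoints registered in routes.py above (POST /api/v0/price/recalculate and POST /api/v0/price/{item_hash}/recalculate) might be exercised once this change is deployed. The paths and the response keys (recalculated_count, total_messages, pricing_changes_found, errors) come from this diff; the node URL and the use of aiohttp as the client are illustrative assumptions only.

import asyncio
from typing import Optional

import aiohttp

# Assumed local node API URL; adjust to the target pyaleph instance.
NODE_URL = "http://localhost:4024"


async def recalculate(item_hash: Optional[str] = None) -> dict:
    """Trigger cost recalculation on a node exposing the new endpoints.

    Without an item hash, every message is recalculated using the historical
    pricing timeline; with an item hash, only that message is recalculated.
    """
    path = (
        f"/api/v0/price/{item_hash}/recalculate"
        if item_hash
        else "/api/v0/price/recalculate"
    )
    async with aiohttp.ClientSession() as session:
        async with session.post(f"{NODE_URL}{path}") as resp:
            resp.raise_for_status()
            return await resp.json()


if __name__ == "__main__":
    result = asyncio.run(recalculate())
    # Keys returned by the handler: recalculated_count, total_messages,
    # pricing_changes_found and, when failures occurred, errors.
    print(result["recalculated_count"], "/", result["total_messages"])

As exercised by test_chronological_processing_order above, the bulk variant processes messages in chronological order so that each one is costed against the pricing aggregate state in effect at its timestamp.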