diff --git a/deployment/migrations/versions/0033_1c06d0ade60c_calculate_costs_statically.py b/deployment/migrations/versions/0033_1c06d0ade60c_calculate_costs_statically.py index b1f36c93..6419c90d 100644 --- a/deployment/migrations/versions/0033_1c06d0ade60c_calculate_costs_statically.py +++ b/deployment/migrations/versions/0033_1c06d0ade60c_calculate_costs_statically.py @@ -7,7 +7,6 @@ """ from decimal import Decimal -from typing import Dict from alembic import op import sqlalchemy as sa import logging @@ -17,7 +16,8 @@ from aleph.db.accessors.cost import make_costs_upsert_query from aleph.db.accessors.messages import get_message_by_item_hash from aleph.services.cost import _is_confidential_vm, get_detailed_costs, CostComputableContent -from aleph.types.cost import ProductComputeUnit, ProductPrice, ProductPriceOptions, ProductPriceType, ProductPricing +from aleph.services.pricing_utils import build_default_pricing_model +from aleph.types.cost import ProductPriceType from aleph.types.db_session import DbSession logger = logging.getLogger("alembic") @@ -30,48 +30,6 @@ depends_on = None -hardcoded_initial_price: Dict[ProductPriceType, ProductPricing] = { - ProductPriceType.PROGRAM: ProductPricing( - ProductPriceType.PROGRAM, - ProductPrice( - ProductPriceOptions("0.05", "0.000000977"), - ProductPriceOptions("200", "0.011") - ), - ProductComputeUnit(1, 2048, 2048) - ), - ProductPriceType.PROGRAM_PERSISTENT: ProductPricing( - ProductPriceType.PROGRAM_PERSISTENT, - ProductPrice( - ProductPriceOptions("0.05", "0.000000977"), - ProductPriceOptions("1000", "0.055") - ), - ProductComputeUnit(1, 20480, 2048) - ), - ProductPriceType.INSTANCE: ProductPricing( - ProductPriceType.INSTANCE, - ProductPrice( - ProductPriceOptions("0.05", "0.000000977"), - ProductPriceOptions("1000", "0.055") - ), - ProductComputeUnit(1, 20480, 2048) - ), - ProductPriceType.INSTANCE_CONFIDENTIAL: ProductPricing( - ProductPriceType.INSTANCE_CONFIDENTIAL, - ProductPrice( - ProductPriceOptions("0.05", 
"0.000000977"), - ProductPriceOptions("2000", "0.11") - ), - ProductComputeUnit(1, 20480, 2048) - ), - ProductPriceType.STORAGE: ProductPricing( - ProductPriceType.STORAGE, - ProductPrice( - ProductPriceOptions("0.333333333"), - ) - ), -} - - def _get_product_instance_type( content: InstanceContent ) -> ProductPriceType: @@ -112,12 +70,15 @@ def do_calculate_costs() -> None: logger.debug("INIT: CALCULATE COSTS FOR: %r", msg_item_hashes) + # Build the initial pricing model from DEFAULT_PRICE_AGGREGATE + initial_pricing_model = build_default_pricing_model() + for item_hash in msg_item_hashes: message = get_message_by_item_hash(session, item_hash) if message: content = message.parsed_content type = _get_product_price_type(content) - pricing = hardcoded_initial_price[type] + pricing = initial_pricing_model[type] costs = get_detailed_costs(session, content, message.item_hash, pricing) if len(costs) > 0: diff --git a/src/aleph/services/pricing_utils.py b/src/aleph/services/pricing_utils.py new file mode 100644 index 00000000..11c1975c --- /dev/null +++ b/src/aleph/services/pricing_utils.py @@ -0,0 +1,122 @@ +""" +Utility functions for pricing model creation and management. +""" + +import datetime as dt +from typing import Dict, List, Union + +from aleph.db.accessors.aggregates import ( + get_aggregate_elements, + merge_aggregate_elements, +) +from aleph.db.models import AggregateElementDb +from aleph.toolkit.constants import ( + DEFAULT_PRICE_AGGREGATE, + PRICE_AGGREGATE_KEY, + PRICE_AGGREGATE_OWNER, +) +from aleph.types.cost import ProductPriceType, ProductPricing +from aleph.types.db_session import DbSession + + +def build_pricing_model_from_aggregate( + aggregate_content: Dict[Union[ProductPriceType, str], dict] +) -> Dict[ProductPriceType, ProductPricing]: + """ + Build a complete pricing model from an aggregate content dictionary. 
+ + This function converts the DEFAULT_PRICE_AGGREGATE format or any pricing aggregate + content into a dictionary of ProductPricing objects that can be used by the cost + calculation functions. + + Args: + aggregate_content: Dictionary containing pricing information with ProductPriceType as keys + + Returns: + Dictionary mapping ProductPriceType to ProductPricing objects + """ + pricing_model: Dict[ProductPriceType, ProductPricing] = {} + + for price_type, pricing_data in aggregate_content.items(): + try: + price_type = ProductPriceType(price_type) + pricing_model[price_type] = ProductPricing.from_aggregate( + price_type, aggregate_content + ) + except (KeyError, ValueError) as e: + # Log the error but continue processing other price types + import logging + + logger = logging.getLogger(__name__) + logger.warning(f"Failed to parse pricing for {price_type}: {e}") + + return pricing_model + + +def build_default_pricing_model() -> Dict[ProductPriceType, ProductPricing]: + """ + Build the default pricing model from DEFAULT_PRICE_AGGREGATE constant. + + Returns: + Dictionary mapping ProductPriceType to ProductPricing objects + """ + return build_pricing_model_from_aggregate(DEFAULT_PRICE_AGGREGATE) + + +def get_pricing_aggregate_history(session: DbSession) -> List[AggregateElementDb]: + """ + Get all pricing aggregate updates in chronological order. + + Args: + session: Database session + + Returns: + List of AggregateElementDb objects ordered by creation_datetime + """ + aggregate_elements = get_aggregate_elements( + session=session, owner=PRICE_AGGREGATE_OWNER, key=PRICE_AGGREGATE_KEY + ) + return list(aggregate_elements) + + +def get_pricing_timeline( + session: DbSession, +) -> List[tuple[dt.datetime, Dict[ProductPriceType, ProductPricing]]]: + """ + Get the complete pricing timeline with timestamps and pricing models. 
+ + This function returns a chronologically ordered list of pricing changes, + useful for processing messages in chronological order and applying the + correct pricing at each point in time. + + This properly merges aggregate elements up to each point in time to create + the cumulative pricing state, similar to how _update_aggregate works. + + Args: + session: Database session + + Returns: + List of tuples containing (timestamp, pricing_model) + """ + pricing_elements = get_pricing_aggregate_history(session) + + timeline = [] + + # Add default pricing as the initial state + timeline.append( + (dt.datetime.min.replace(tzinfo=dt.timezone.utc), build_default_pricing_model()) + ) + + # Build cumulative pricing models by merging elements up to each timestamp + elements_so_far = [] + for element in pricing_elements: + elements_so_far.append(element) + + # Merge all elements up to this point to get the cumulative state + merged_content = merge_aggregate_elements(elements_so_far) + + # Build pricing model from the merged content + pricing_model = build_pricing_model_from_aggregate(merged_content) + timeline.append((element.creation_datetime, pricing_model)) + + return timeline diff --git a/src/aleph/toolkit/constants.py b/src/aleph/toolkit/constants.py index d33ccc1a..129df7d3 100644 --- a/src/aleph/toolkit/constants.py +++ b/src/aleph/toolkit/constants.py @@ -1,3 +1,7 @@ +from typing import Dict, Union + +from aleph.types.cost import ProductPriceType + KiB = 1024 MiB = 1024 * 1024 GiB = 1024 * 1024 * 1024 @@ -8,8 +12,8 @@ PRICE_AGGREGATE_OWNER = "0xFba561a84A537fCaa567bb7A2257e7142701ae2A" PRICE_AGGREGATE_KEY = "pricing" PRICE_PRECISION = 18 -DEFAULT_PRICE_AGGREGATE = { - "program": { +DEFAULT_PRICE_AGGREGATE: Dict[Union[ProductPriceType, str], dict] = { + ProductPriceType.PROGRAM: { "price": { "storage": {"payg": "0.000000977", "holding": "0.05"}, "compute_unit": {"payg": "0.011", "holding": "200"}, @@ -28,8 +32,8 @@ "memory_mib": 2048, }, }, - "storage": {"price": 
{"storage": {"holding": "0.333333333"}}}, - "instance": { + ProductPriceType.STORAGE: {"price": {"storage": {"holding": "0.333333333"}}}, + ProductPriceType.INSTANCE: { "price": { "storage": {"payg": "0.000000977", "holding": "0.05"}, "compute_unit": {"payg": "0.055", "holding": "1000"}, @@ -48,8 +52,10 @@ "memory_mib": 2048, }, }, - "web3_hosting": {"price": {"fixed": 50, "storage": {"holding": "0.333333333"}}}, - "program_persistent": { + ProductPriceType.WEB3_HOSTING: { + "price": {"fixed": 50, "storage": {"holding": "0.333333333"}} + }, + ProductPriceType.PROGRAM_PERSISTENT: { "price": { "storage": {"payg": "0.000000977", "holding": "0.05"}, "compute_unit": {"payg": "0.055", "holding": "1000"}, @@ -68,7 +74,7 @@ "memory_mib": 2048, }, }, - "instance_gpu_premium": { + ProductPriceType.INSTANCE_GPU_PREMIUM: { "price": { "storage": {"payg": "0.000000977"}, "compute_unit": {"payg": "0.56"}, @@ -93,7 +99,7 @@ "memory_mib": 6144, }, }, - "instance_confidential": { + ProductPriceType.INSTANCE_CONFIDENTIAL: { "price": { "storage": {"payg": "0.000000977", "holding": "0.05"}, "compute_unit": {"payg": "0.11", "holding": "2000"}, @@ -112,7 +118,7 @@ "memory_mib": 2048, }, }, - "instance_gpu_standard": { + ProductPriceType.INSTANCE_GPU_STANDARD: { "price": { "storage": {"payg": "0.000000977"}, "compute_unit": {"payg": "0.28"}, diff --git a/src/aleph/web/controllers/prices.py b/src/aleph/web/controllers/prices.py index 48fc9df3..6f40796d 100644 --- a/src/aleph/web/controllers/prices.py +++ b/src/aleph/web/controllers/prices.py @@ -1,15 +1,17 @@ import logging from dataclasses import dataclass from decimal import Decimal -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from aiohttp import web from aiohttp.web_exceptions import HTTPException from aleph_message.models import ExecutableContent, ItemHash, MessageType from dataclasses_json import DataClassJsonMixin from pydantic import BaseModel, Field +from sqlalchemy import select import 
aleph.toolkit.json as aleph_json +from aleph.db.accessors.cost import delete_costs_for_message, make_costs_upsert_query from aleph.db.accessors.messages import get_message_by_item_hash, get_message_status from aleph.db.models import MessageDb from aleph.schemas.api.costs import EstimatedCostsResponse @@ -18,10 +20,14 @@ validate_cost_estimation_message_dict, ) from aleph.services.cost import ( + _get_product_price_type, + _get_settings, + get_detailed_costs, get_payment_type, get_total_and_detailed_costs, get_total_and_detailed_costs_from_db, ) +from aleph.services.pricing_utils import get_pricing_timeline from aleph.toolkit.costs import format_cost_str from aleph.types.db_session import DbSession from aleph.types.message_status import MessageStatus @@ -164,3 +170,130 @@ async def message_price_estimate(request: web.Request): response = EstimatedCostsResponse.model_validate(model) return web.json_response(text=aleph_json.dumps(response).decode("utf-8")) + + +async def recalculate_message_costs(request: web.Request): + """Force recalculation of message costs in chronological order with historical pricing. + + This endpoint will: + 1. Get all messages that need cost recalculation (if item_hash provided, just that message) + 2. Get the pricing timeline to track price changes over time + 3. Sort messages chronologically (oldest first) + 4. For each message, use the pricing model that was active when the message was created + 5. Delete existing cost entries and recalculate with historical pricing + 6. 
Store the new cost calculations + """ + + session_factory = get_session_factory_from_request(request) + + # Check if a specific message hash was provided + item_hash_param = request.match_info.get("item_hash") + + with session_factory() as session: + messages_to_recalculate: List[MessageDb] = [] + + if item_hash_param: + # Recalculate costs for a specific message + try: + message = await get_executable_message(session, item_hash_param) + messages_to_recalculate = [message] + except HTTPException: + raise + else: + # Recalculate costs for all executable messages, ordered by time (oldest first) + select_stmt = ( + select(MessageDb) + .where( + MessageDb.type.in_( + [MessageType.instance, MessageType.program, MessageType.store] + ) + ) + .order_by(MessageDb.time.asc()) + ) + result = session.execute(select_stmt) + messages_to_recalculate = result.scalars().all() + + if not messages_to_recalculate: + return web.json_response( + { + "message": "No messages found for cost recalculation", + "recalculated_count": 0, + } + ) + + # Get the pricing timeline to track price changes over time + pricing_timeline = get_pricing_timeline(session) + LOGGER.info(f"Found {len(pricing_timeline)} pricing changes in timeline") + + recalculated_count = 0 + errors = [] + current_pricing_model = None + current_pricing_index = 0 + + settings = _get_settings(session) + + for message in messages_to_recalculate: + try: + # Find the applicable pricing model for this message's timestamp + while ( + current_pricing_index < len(pricing_timeline) - 1 + and pricing_timeline[current_pricing_index + 1][0] <= message.time + ): + current_pricing_index += 1 + + current_pricing_model = pricing_timeline[current_pricing_index][1] + pricing_timestamp = pricing_timeline[current_pricing_index][0] + + LOGGER.debug( + f"Message {message.item_hash} at {message.time} using pricing from {pricing_timestamp}" + ) + + # Delete existing cost entries for this message + delete_costs_for_message(session, message.item_hash) 
+ + # Get the message content and determine product type + content: ExecutableContent = message.parsed_content + product_type = _get_product_price_type( + content, settings, current_pricing_model + ) + + # Get the pricing for this specific product type + if product_type not in current_pricing_model: + LOGGER.warning( + f"Product type {product_type} not found in pricing model for message {message.item_hash}" + ) + continue + + pricing = current_pricing_model[product_type] + + # Calculate new costs using the historical pricing model + new_costs = get_detailed_costs( + session, content, message.item_hash, pricing + ) + + if new_costs: + # Store the new cost calculations + upsert_stmt = make_costs_upsert_query(new_costs) + session.execute(upsert_stmt) + + recalculated_count += 1 + + except Exception as e: + error_msg = f"Failed to recalculate costs for message {message.item_hash}: {str(e)}" + LOGGER.error(error_msg) + errors.append({"item_hash": message.item_hash, "error": str(e)}) + + # Commit all changes + session.commit() + + response_data = { + "message": "Cost recalculation completed with historical pricing", + "recalculated_count": recalculated_count, + "total_messages": len(messages_to_recalculate), + "pricing_changes_found": len(pricing_timeline), + } + + if errors: + response_data["errors"] = errors + + return web.json_response(response_data) diff --git a/src/aleph/web/controllers/routes.py b/src/aleph/web/controllers/routes.py index b4678fdf..ce931480 100644 --- a/src/aleph/web/controllers/routes.py +++ b/src/aleph/web/controllers/routes.py @@ -67,6 +67,10 @@ def register_routes(app: web.Application): app.router.add_get("/api/v0/price/{item_hash}", prices.message_price) app.router.add_post("/api/v0/price/estimate", prices.message_price_estimate) + app.router.add_post("/api/v0/price/recalculate", prices.recalculate_message_costs) + app.router.add_post( + "/api/v0/price/{item_hash}/recalculate", prices.recalculate_message_costs + ) 
app.router.add_get("/api/v0/addresses/stats.json", accounts.addresses_stats_view) app.router.add_get( diff --git a/tests/api/test_pricing_recalculation.py b/tests/api/test_pricing_recalculation.py new file mode 100644 index 00000000..f42d9015 --- /dev/null +++ b/tests/api/test_pricing_recalculation.py @@ -0,0 +1,466 @@ +import datetime as dt +import json +from decimal import Decimal +from unittest.mock import patch + +import pytest +from aiohttp import web +from aleph_message.models import MessageType + +from aleph.db.models import MessageDb +from aleph.db.models.account_costs import AccountCostsDb +from aleph.db.models.aggregates import AggregateElementDb +from aleph.toolkit.constants import PRICE_AGGREGATE_KEY, PRICE_AGGREGATE_OWNER +from aleph.types.cost import ProductPriceType +from aleph.web.controllers.prices import recalculate_message_costs + + +@pytest.fixture +def sample_messages(session_factory): + """Create sample messages for testing cost recalculation.""" + base_time = dt.datetime(2024, 1, 1, 10, 0, 0, tzinfo=dt.timezone.utc) + + # Create sample instance message + instance_message = MessageDb( + item_hash="instance_msg_1", + type=MessageType.instance, + chain="ETH", + sender="0xTest1", + item_type="inline", + content={ + "time": (base_time + dt.timedelta(hours=1)).timestamp(), + "rootfs": { + "parent": {"ref": "test_ref", "use_latest": True}, + "size_mib": 20480, + "persistence": "host", + }, + "address": "0xTest1", + "volumes": [], + "metadata": {"name": "Test Instance"}, + "resources": {"vcpus": 1, "memory": 2048, "seconds": 30}, + "allow_amend": False, + "environment": {"internet": True, "aleph_api": True}, + }, + time=base_time + dt.timedelta(hours=1), + size=1024, + ) + + # Create sample program message + program_message = MessageDb( + item_hash="program_msg_1", + type=MessageType.program, + chain="ETH", + sender="0xTest2", + item_type="inline", + content={ + "time": (base_time + dt.timedelta(hours=2)).timestamp(), + "on": {"http": True, 
"persistent": False}, + "code": { + "ref": "code_ref", + "encoding": "zip", + "entrypoint": "main:app", + "use_latest": True, + }, + "runtime": {"ref": "runtime_ref", "use_latest": True}, + "address": "0xTest2", + "resources": {"vcpus": 1, "memory": 128, "seconds": 30}, + "allow_amend": False, + "environment": {"internet": True, "aleph_api": True}, + }, + time=base_time + dt.timedelta(hours=2), + size=512, + ) + + # Create sample store message + store_message = MessageDb( + item_hash="store_msg_1", + type=MessageType.store, + chain="ETH", + sender="0xTest3", + item_type="inline", + content={ + "time": (base_time + dt.timedelta(hours=3)).timestamp(), + "item_type": "storage", + "item_hash": "stored_file_hash", + "address": "0xTest3", + }, + time=base_time + dt.timedelta(hours=3), + size=2048, + ) + + with session_factory() as session: + session.add(instance_message) + session.add(program_message) + session.add(store_message) + session.commit() + session.refresh(instance_message) + session.refresh(program_message) + session.refresh(store_message) + + return [instance_message, program_message, store_message] + + +@pytest.fixture +def pricing_updates_with_timeline(session_factory): + """Create pricing updates that form a timeline for testing.""" + base_time = dt.datetime(2024, 1, 1, 9, 0, 0, tzinfo=dt.timezone.utc) + + # First pricing update - before any messages + element1 = AggregateElementDb( + item_hash="pricing_1", + key=PRICE_AGGREGATE_KEY, + owner=PRICE_AGGREGATE_OWNER, + content={ + ProductPriceType.STORAGE: {"price": {"storage": {"holding": "0.1"}}}, + ProductPriceType.INSTANCE: { + "price": { + "storage": {"holding": "0.05"}, + "compute_unit": {"holding": "500"}, + }, + "compute_unit": {"vcpus": 1, "disk_mib": 20480, "memory_mib": 2048}, + }, + }, + creation_datetime=base_time + dt.timedelta(minutes=30), + ) + + # Second pricing update - between instance and program messages + element2 = AggregateElementDb( + item_hash="pricing_2", + key=PRICE_AGGREGATE_KEY, 
+ owner=PRICE_AGGREGATE_OWNER, + content={ + ProductPriceType.PROGRAM: { + "price": { + "storage": {"holding": "0.03"}, + "compute_unit": {"holding": "150"}, + }, + "compute_unit": {"vcpus": 1, "disk_mib": 2048, "memory_mib": 2048}, + } + }, + creation_datetime=base_time + dt.timedelta(hours=1, minutes=30), + ) + + # Third pricing update - after program but before store message + element3 = AggregateElementDb( + item_hash="pricing_3", + key=PRICE_AGGREGATE_KEY, + owner=PRICE_AGGREGATE_OWNER, + content={ + ProductPriceType.STORAGE: { + "price": {"storage": {"holding": "0.2"}} # Updated storage price + } + }, + creation_datetime=base_time + dt.timedelta(hours=2, minutes=30), + ) + + with session_factory() as session: + session.add(element1) + session.add(element2) + session.add(element3) + session.commit() + session.refresh(element1) + session.refresh(element2) + session.refresh(element3) + + return [element1, element2, element3] + + +@pytest.fixture +def existing_costs(session_factory, sample_messages): + """Create some existing cost entries to test deletion and recalculation.""" + costs = [] + + for message in sample_messages: + cost = AccountCostsDb( + owner=message.sender, + item_hash=message.item_hash, + type="EXECUTION", + name="old_cost", + payment_type="hold", + cost_hold=Decimal("999.99"), # Old/incorrect cost + cost_stream=Decimal("0.01"), + ) + costs.append(cost) + + with session_factory() as session: + for cost in costs: + session.add(cost) + session.commit() + + return costs + + +class TestRecalculateMessageCosts: + """Tests for the message cost recalculation endpoint.""" + + @pytest.fixture + def mock_request_factory(self, session_factory): + """Factory to create mock requests.""" + + def _create_mock_request(match_info=None): + request = web.Request.__new__(web.Request) + request._match_info = match_info or {} + + # Mock the session factory getter + def get_session_factory(): + return session_factory + + request._session_factory = get_session_factory + 
return request + + return _create_mock_request + + @patch("aleph.web.controllers.prices.get_session_factory_from_request") + async def test_recalculate_all_messages_empty_db( + self, mock_get_session, session_factory, mock_request_factory + ): + """Test recalculation when no messages exist.""" + mock_get_session.return_value = session_factory + request = mock_request_factory() + + response = await recalculate_message_costs(request) + + assert response.status == 200 + response_data = json.loads(response.text) + assert response_data["recalculated_count"] == 0 + assert response_data["total_messages"] == 0 + assert "No messages found" in response_data["message"] + + @patch("aleph.web.controllers.prices.get_session_factory_from_request") + @patch("aleph.web.controllers.prices.get_executable_message") + async def test_recalculate_specific_message( + self, + mock_get_executable, + mock_get_session, + session_factory, + sample_messages, + mock_request_factory, + ): + """Test recalculation of a specific message.""" + mock_get_session.return_value = session_factory + mock_get_executable.return_value = sample_messages[0] # Return first message + + request = mock_request_factory({"item_hash": "instance_msg_1"}) + + with patch("aleph.web.controllers.prices.get_detailed_costs") as mock_get_costs: + mock_get_costs.return_value = [] # Mock empty costs + + response = await recalculate_message_costs(request) + + assert response.status == 200 + response_data = json.loads(response.text) + assert response_data["recalculated_count"] == 1 + assert response_data["total_messages"] == 1 + assert "historical pricing" in response_data["message"] + + # Should have called get_detailed_costs once + assert mock_get_costs.call_count == 1 + + @patch("aleph.web.controllers.prices.get_session_factory_from_request") + async def test_recalculate_all_messages_with_timeline( + self, + mock_get_session, + session_factory, + sample_messages, + pricing_updates_with_timeline, + existing_costs, + 
mock_request_factory, + ): + """Test recalculation of all messages with pricing timeline.""" + mock_get_session.return_value = session_factory + + request = mock_request_factory() + + with patch("aleph.web.controllers.prices.get_detailed_costs") as mock_get_costs: + mock_get_costs.return_value = [] # Mock empty costs + + response = await recalculate_message_costs(request) + + assert response.status == 200 + response_data = json.loads(response.text) + assert response_data["recalculated_count"] == 3 + assert response_data["total_messages"] == 3 + assert response_data["pricing_changes_found"] == 4 # Default + 3 updates + + # Should have called get_detailed_costs for each message + assert mock_get_costs.call_count == 3 + + # Verify old costs were deleted + with session_factory() as session: + remaining_costs = session.query(AccountCostsDb).all() + assert len(remaining_costs) == 0 # All old costs should be deleted + + @patch("aleph.web.controllers.prices.get_session_factory_from_request") + async def test_recalculate_with_pricing_timeline_application( + self, + mock_get_session, + session_factory, + sample_messages, + pricing_updates_with_timeline, + mock_request_factory, + ): + """Test that the correct pricing model is applied based on message timestamps.""" + mock_get_session.return_value = session_factory + + request = mock_request_factory() + + pricing_calls = [] + + def mock_get_costs(session, content, item_hash, pricing): + # Capture the pricing object used for each call + pricing_calls.append((item_hash, pricing.type if pricing else None)) + return [] + + with patch( + "aleph.web.controllers.prices.get_detailed_costs", + side_effect=mock_get_costs, + ): + response = await recalculate_message_costs(request) + + assert response.status == 200 + + # Should have made calls for all 3 messages + assert len(pricing_calls) == 3 + + # Verify the correct pricing types were used (based on message content and timeline) + item_hashes = [call[0] for call in pricing_calls] + 
assert "instance_msg_1" in item_hashes + assert "program_msg_1" in item_hashes + assert "store_msg_1" in item_hashes + + @patch("aleph.web.controllers.prices.get_session_factory_from_request") + async def test_recalculate_with_errors( + self, mock_get_session, session_factory, sample_messages, mock_request_factory + ): + """Test recalculation handling of errors.""" + mock_get_session.return_value = session_factory + + request = mock_request_factory() + + def mock_get_costs_with_error(session, content, item_hash, pricing): + if item_hash == "program_msg_1": + raise ValueError("Test error for program message") + return [] + + with patch( + "aleph.web.controllers.prices.get_detailed_costs", + side_effect=mock_get_costs_with_error, + ): + response = await recalculate_message_costs(request) + + assert response.status == 200 + response_data = json.loads(response.text) + + # Should have processed 2 successfully, 1 with error + assert response_data["recalculated_count"] == 2 + assert response_data["total_messages"] == 3 + assert "errors" in response_data + assert len(response_data["errors"]) == 1 + assert response_data["errors"][0]["item_hash"] == "program_msg_1" + assert "Test error" in response_data["errors"][0]["error"] + + @patch("aleph.web.controllers.prices.get_session_factory_from_request") + @patch("aleph.web.controllers.prices.get_executable_message") + async def test_recalculate_specific_message_not_found( + self, + mock_get_executable, + mock_get_session, + session_factory, + mock_request_factory, + ): + """Test recalculation of a specific message that doesn't exist.""" + mock_get_session.return_value = session_factory + mock_get_executable.side_effect = web.HTTPNotFound(body="Message not found") + + request = mock_request_factory({"item_hash": "nonexistent_hash"}) + + with pytest.raises(web.HTTPNotFound): + await recalculate_message_costs(request) + + @patch("aleph.web.controllers.prices.get_session_factory_from_request") + async def 
test_chronological_processing_order( + self, mock_get_session, session_factory, sample_messages, mock_request_factory + ): + """Test that messages are processed in chronological order.""" + mock_get_session.return_value = session_factory + + request = mock_request_factory() + + processed_order = [] + + def mock_get_costs(session, content, item_hash, pricing): + processed_order.append(item_hash) + return [] + + with patch( + "aleph.web.controllers.prices.get_detailed_costs", + side_effect=mock_get_costs, + ): + response = await recalculate_message_costs(request) + + assert response.status == 200 + + # Should have processed in chronological order based on message.time + expected_order = ["instance_msg_1", "program_msg_1", "store_msg_1"] + assert processed_order == expected_order + + +class TestPricingTimelineIntegration: + """Integration tests for the complete pricing timeline feature.""" + + @pytest.fixture + def mock_request_factory(self, session_factory): + """Factory to create mock requests.""" + + def _create_mock_request(match_info=None): + request = web.Request.__new__(web.Request) + request._match_info = match_info or {} + + # Mock the session factory getter + def get_session_factory(): + return session_factory + + request._session_factory = get_session_factory + return request + + return _create_mock_request + + @patch("aleph.web.controllers.prices.get_session_factory_from_request") + async def test_end_to_end_historical_pricing( + self, + mock_get_session, + session_factory, + sample_messages, + pricing_updates_with_timeline, + mock_request_factory, + ): + """End-to-end test of historical pricing application.""" + mock_get_session.return_value = session_factory + + request = mock_request_factory() + + # Track which pricing models are used for each message + pricing_usage = {} + + def mock_get_costs(session, content, item_hash, pricing): + if pricing and hasattr(pricing, "price"): + if hasattr(pricing.price, "storage") and hasattr( + pricing.price.storage, 
"holding" + ): + pricing_usage[item_hash] = float(pricing.price.storage.holding) + return [] + + with patch( + "aleph.web.controllers.prices.get_detailed_costs", + side_effect=mock_get_costs, + ): + response = await recalculate_message_costs(request) + + assert response.status == 200 + response_data = json.loads(response.text) + assert response_data["recalculated_count"] == 3 + + # Verify that different pricing was applied based on timeline + # The exact values depend on the pricing timeline and merge logic, + # but we can verify that historical pricing was considered + assert len(pricing_usage) > 0 diff --git a/tests/services/test_pricing_utils.py b/tests/services/test_pricing_utils.py new file mode 100644 index 00000000..5f182e69 --- /dev/null +++ b/tests/services/test_pricing_utils.py @@ -0,0 +1,407 @@ +import datetime as dt +from decimal import Decimal + +import pytest +from aleph_message.models import InstanceContent + +from aleph.db.models.aggregates import AggregateElementDb +from aleph.services.pricing_utils import ( + build_default_pricing_model, + build_pricing_model_from_aggregate, + get_pricing_aggregate_history, + get_pricing_timeline, +) +from aleph.toolkit.constants import PRICE_AGGREGATE_KEY, PRICE_AGGREGATE_OWNER +from aleph.types.cost import ProductPriceType, ProductPricing + + +@pytest.fixture +def sample_pricing_aggregate_content(): + """Sample pricing aggregate content with ProductPriceType keys.""" + return { + ProductPriceType.STORAGE: {"price": {"storage": {"holding": "0.5"}}}, + ProductPriceType.PROGRAM: { + "price": { + "storage": {"payg": "0.000001", "holding": "0.1"}, + "compute_unit": {"payg": "0.02", "holding": "300"}, + }, + "compute_unit": { + "vcpus": 1, + "disk_mib": 2048, + "memory_mib": 2048, + }, + }, + ProductPriceType.INSTANCE: { + "price": { + "storage": {"payg": "0.000001", "holding": "0.1"}, + "compute_unit": {"payg": "0.1", "holding": "1500"}, + }, + "compute_unit": { + "vcpus": 1, + "disk_mib": 20480, + "memory_mib": 
# NOTE(review): the original chunk opens with the closing braces of the
# `sample_pricing_aggregate_content` fixture, whose definition starts before
# this view; it is referenced by tests below but cannot be reconstructed here.


@pytest.fixture
def pricing_aggregate_elements(session_factory):
    """Create sample pricing aggregate elements for timeline testing.

    Persists three PRICE aggregate elements one hour apart (base time
    2024-01-01 12:00 UTC + 1h/2h/3h) and returns them in chronological
    order: storage-only, program-only, then storage-update + instance.
    """
    base_time = dt.datetime(2024, 1, 1, 12, 0, 0, tzinfo=dt.timezone.utc)

    # First pricing update - only storage pricing
    element1 = AggregateElementDb(
        item_hash="pricing_update_1",
        key=PRICE_AGGREGATE_KEY,
        owner=PRICE_AGGREGATE_OWNER,
        content={ProductPriceType.STORAGE: {"price": {"storage": {"holding": "0.2"}}}},
        creation_datetime=base_time + dt.timedelta(hours=1),
    )

    # Second pricing update - adds program pricing
    element2 = AggregateElementDb(
        item_hash="pricing_update_2",
        key=PRICE_AGGREGATE_KEY,
        owner=PRICE_AGGREGATE_OWNER,
        content={
            ProductPriceType.PROGRAM: {
                "price": {
                    "storage": {"payg": "0.000001", "holding": "0.08"},
                    "compute_unit": {"payg": "0.015", "holding": "250"},
                },
                "compute_unit": {
                    "vcpus": 1,
                    "disk_mib": 2048,
                    "memory_mib": 2048,
                },
            }
        },
        creation_datetime=base_time + dt.timedelta(hours=2),
    )

    # Third pricing update - updates storage and adds instance pricing
    element3 = AggregateElementDb(
        item_hash="pricing_update_3",
        key=PRICE_AGGREGATE_KEY,
        owner=PRICE_AGGREGATE_OWNER,
        content={
            ProductPriceType.STORAGE: {"price": {"storage": {"holding": "0.3"}}},
            ProductPriceType.INSTANCE: {
                "price": {
                    "storage": {"payg": "0.000001", "holding": "0.05"},
                    "compute_unit": {"payg": "0.06", "holding": "1200"},
                },
                "compute_unit": {
                    "vcpus": 1,
                    "disk_mib": 20480,
                    "memory_mib": 2048,
                },
            },
        },
        creation_datetime=base_time + dt.timedelta(hours=3),
    )

    with session_factory() as session:
        session.add(element1)
        session.add(element2)
        session.add(element3)
        session.commit()
        # Refresh so DB-generated attributes are loaded before the session closes.
        session.refresh(element1)
        session.refresh(element2)
        session.refresh(element3)

    return [element1, element2, element3]


class TestBuildPricingModelFromAggregate:
    """Tests for building pricing models from aggregate content."""

    def test_build_pricing_model_from_aggregate(
        self, sample_pricing_aggregate_content
    ):
        """Test building pricing model from aggregate content."""
        pricing_model = build_pricing_model_from_aggregate(
            sample_pricing_aggregate_content
        )

        # Check that we got ProductPricing objects for each type
        # (the fixture is expected to define exactly these three types).
        assert len(pricing_model) == 3
        assert ProductPriceType.STORAGE in pricing_model
        assert ProductPriceType.PROGRAM in pricing_model
        assert ProductPriceType.INSTANCE in pricing_model

        # Check that each is a ProductPricing object keyed by its own type
        for price_type, pricing in pricing_model.items():
            assert isinstance(pricing, ProductPricing)
            assert pricing.type == price_type

    def test_build_pricing_model_with_missing_types(self):
        """Test building pricing model with some missing product types."""
        partial_content = {
            ProductPriceType.STORAGE: {"price": {"storage": {"holding": "0.5"}}}
        }

        pricing_model = build_pricing_model_from_aggregate(partial_content)

        # Only the supplied type should be present; absent types are not defaulted.
        assert len(pricing_model) == 1
        assert ProductPriceType.STORAGE in pricing_model
        assert ProductPriceType.PROGRAM not in pricing_model

    def test_build_pricing_model_with_invalid_data(self):
        """Test building pricing model with invalid pricing data."""
        invalid_content = {
            ProductPriceType.STORAGE: {
                "invalid": "data"  # Missing required "price" key
            }
        }

        # Should handle the error gracefully and return empty model
        # (invalid entries are dropped, not raised).
        pricing_model = build_pricing_model_from_aggregate(invalid_content)
        assert len(pricing_model) == 0


class TestBuildDefaultPricingModel:
    """Tests for building the default pricing model."""

    def test_build_default_pricing_model(self):
        """Test building the default pricing model from constants."""
        pricing_model = build_default_pricing_model()

        # Should contain all the expected product types from DEFAULT_PRICE_AGGREGATE
        expected_types = [
            ProductPriceType.PROGRAM,
            ProductPriceType.STORAGE,
            ProductPriceType.INSTANCE,
            ProductPriceType.PROGRAM_PERSISTENT,
            ProductPriceType.INSTANCE_GPU_PREMIUM,
            ProductPriceType.INSTANCE_CONFIDENTIAL,
            ProductPriceType.INSTANCE_GPU_STANDARD,
            ProductPriceType.WEB3_HOSTING,
        ]

        for price_type in expected_types:
            assert price_type in pricing_model
            assert isinstance(pricing_model[price_type], ProductPricing)


class TestGetPricingAggregateHistory:
    """Tests for retrieving pricing aggregate history."""

    def test_get_pricing_aggregate_history_empty(self, session_factory):
        """Test getting pricing history when no elements exist."""
        with session_factory() as session:
            history = get_pricing_aggregate_history(session)
            assert len(history) == 0

    def test_get_pricing_aggregate_history_with_elements(
        self, session_factory, pricing_aggregate_elements
    ):
        """Test getting pricing history with existing elements."""
        with session_factory() as session:
            history = get_pricing_aggregate_history(session)

            assert len(history) == 3

            # Should be ordered chronologically (strictly increasing here,
            # since the fixture spaces elements one hour apart)
            assert history[0].creation_datetime < history[1].creation_datetime
            assert history[1].creation_datetime < history[2].creation_datetime

            # Check content matches the fixture's per-element payloads
            assert ProductPriceType.STORAGE in history[0].content
            assert ProductPriceType.PROGRAM in history[1].content
            assert ProductPriceType.STORAGE in history[2].content
            assert ProductPriceType.INSTANCE in history[2].content


class TestGetPricingTimeline:
    """Tests for getting the pricing timeline."""

    def test_get_pricing_timeline_empty(self, session_factory):
        """Test getting pricing timeline when no aggregate elements exist."""
        with session_factory() as session:
            timeline = get_pricing_timeline(session)

            # Should have at least the default pricing
            assert len(timeline) == 1
            timestamp, pricing_model = timeline[0]

            # Should use minimum datetime for default pricing so it sorts
            # before any real aggregate element
            assert timestamp == dt.datetime.min.replace(tzinfo=dt.timezone.utc)
            assert isinstance(pricing_model, dict)
            assert ProductPriceType.STORAGE in pricing_model

    def test_get_pricing_timeline_with_elements(
        self, session_factory, pricing_aggregate_elements
    ):
        """Test getting pricing timeline with aggregate elements.

        Each timeline entry after the default is expected to be the
        cumulative merge of all aggregate elements up to that timestamp
        (see test_pricing_timeline_cumulative_merging).
        """
        with session_factory() as session:
            timeline = get_pricing_timeline(session)

            # Should have default + 3 pricing updates
            assert len(timeline) == 4

            # Check chronological order
            for i in range(len(timeline) - 1):
                assert timeline[i][0] <= timeline[i + 1][0]

            # Check content evolution
            default_timestamp, default_model = timeline[0]
            first_timestamp, first_model = timeline[1]
            second_timestamp, second_model = timeline[2]
            third_timestamp, third_model = timeline[3]

            # First update: only storage pricing
            assert ProductPriceType.STORAGE in first_model
            storage_pricing_1 = first_model[ProductPriceType.STORAGE]
            assert storage_pricing_1.price.storage.holding == Decimal("0.2")

            # Second update: storage + program pricing (cumulative)
            assert ProductPriceType.STORAGE in second_model
            assert ProductPriceType.PROGRAM in second_model
            storage_pricing_2 = second_model[ProductPriceType.STORAGE]
            assert storage_pricing_2.price.storage.holding == Decimal(
                "0.2"
            )  # Still from first update

            # Third update: updated storage + program + instance pricing (cumulative)
            assert ProductPriceType.STORAGE in third_model
            assert ProductPriceType.PROGRAM in third_model
            assert ProductPriceType.INSTANCE in third_model
            storage_pricing_3 = third_model[ProductPriceType.STORAGE]
            assert storage_pricing_3.price.storage.holding == Decimal(
                "0.3"
            )  # Updated value

    def test_pricing_timeline_cumulative_merging(self, session_factory):
        """Test that pricing timeline properly merges cumulative changes."""
        base_time = dt.datetime(2024, 1, 1, 12, 0, 0, tzinfo=dt.timezone.utc)

        # Create elements that update different parts of pricing
        element1 = AggregateElementDb(
            item_hash="test_1",
            key=PRICE_AGGREGATE_KEY,
            owner=PRICE_AGGREGATE_OWNER,
            content={
                ProductPriceType.STORAGE: {"price": {"storage": {"holding": "1.0"}}},
                ProductPriceType.PROGRAM: {
                    "price": {
                        "storage": {"holding": "0.1"},
                        "compute_unit": {"holding": "100"},
                    },
                    "compute_unit": {"vcpus": 1, "disk_mib": 1024, "memory_mib": 1024},
                },
            },
            creation_datetime=base_time + dt.timedelta(hours=1),
        )

        # Second element only updates storage, should preserve program settings
        element2 = AggregateElementDb(
            item_hash="test_2",
            key=PRICE_AGGREGATE_KEY,
            owner=PRICE_AGGREGATE_OWNER,
            content={
                ProductPriceType.STORAGE: {
                    "price": {"storage": {"holding": "2.0"}}  # Updated price
                }
            },
            creation_datetime=base_time + dt.timedelta(hours=2),
        )

        with session_factory() as session:
            session.add(element1)
            session.add(element2)
            session.commit()

            timeline = get_pricing_timeline(session)

            # Should have default + 2 updates
            assert len(timeline) == 3

            # Check final state has both storage and program, with updated storage price
            final_timestamp, final_model = timeline[2]

            assert ProductPriceType.STORAGE in final_model
            assert ProductPriceType.PROGRAM in final_model

            # Storage should have updated price
            storage_pricing = final_model[ProductPriceType.STORAGE]
            assert storage_pricing.price.storage.holding == Decimal("2.0")

            # Program should still have original settings
            program_pricing = final_model[ProductPriceType.PROGRAM]
            assert program_pricing.price.storage.holding == Decimal("0.1")
            assert program_pricing.price.compute_unit.holding == Decimal("100")


class TestPricingTimelineIntegration:
    """Integration tests for pricing timeline with real message types."""

    @pytest.fixture
    def sample_instance_content(self):
        """Sample instance content for testing."""
        return InstanceContent.model_validate(
            {
                "time": 1701099523.849,
                "rootfs": {
                    "parent": {
                        "ref": "549ec451d9b099cad112d4aaa2c00ac40fb6729a92ff252ff22eef0b5c3cb613",
                        "use_latest": True,
                    },
                    "size_mib": 20480,
                    "persistence": "host",
                },
                "address": "0xTest",
                "volumes": [],
                "metadata": {"name": "Test Instance"},
                "resources": {"vcpus": 1, "memory": 2048, "seconds": 30},
                "allow_amend": False,
                "environment": {"internet": True, "aleph_api": True},
            }
        )

    def test_pricing_timeline_with_message_processing(
        self, session_factory, pricing_aggregate_elements, sample_instance_content
    ):
        """Test that pricing timeline can be used for historical message cost calculation."""
        with session_factory() as session:
            timeline = get_pricing_timeline(session)

            # Simulate processing a message at different points in time
            message_time_1 = dt.datetime(
                2024, 1, 1, 13, 30, 0, tzinfo=dt.timezone.utc
            )  # Between first and second update
            message_time_2 = dt.datetime(
                2024, 1, 1, 15, 30, 0, tzinfo=dt.timezone.utc
            )  # After all updates

            # Find applicable pricing for each message time: walk the
            # chronologically ordered timeline and keep the last model whose
            # timestamp is not after the message time.
            pricing_1 = None
            pricing_2 = None

            for timestamp, pricing_model in timeline:
                if timestamp <= message_time_1:
                    pricing_1 = pricing_model
                if timestamp <= message_time_2:
                    pricing_2 = pricing_model

            # At time 1, should have storage pricing but not instance pricing
            assert pricing_1 is not None
            assert ProductPriceType.STORAGE in pricing_1
            # Instance pricing not added until third update
            # NOTE(review): this implies get_pricing_timeline does NOT fold the
            # default pricing model (which includes INSTANCE, per
            # test_build_default_pricing_model) into later entries — it only
            # merges aggregate elements cumulatively. Confirm against the
            # implementation.
            assert ProductPriceType.INSTANCE not in pricing_1

            # At time 2, should have both storage and instance pricing
            assert pricing_2 is not None
            assert ProductPriceType.STORAGE in pricing_2
            assert ProductPriceType.INSTANCE in pricing_2

            # Storage pricing should be different between the two time points
            if (
                ProductPriceType.STORAGE in pricing_1
                and ProductPriceType.STORAGE in pricing_2
            ):
                storage_1 = pricing_1[ProductPriceType.STORAGE]
                storage_2 = pricing_2[ProductPriceType.STORAGE]
                # Should have different prices due to the third update
                assert (
                    storage_1.price.storage.holding != storage_2.price.storage.holding
                )