diff --git a/aimon/decorators/detect.py b/aimon/decorators/detect.py
index 99c7acf..29ac6fd 100644
--- a/aimon/decorators/detect.py
+++ b/aimon/decorators/detect.py
@@ -91,6 +91,8 @@ class Detect:
         The name of the application to use when publish is True.
     model_name : str, optional
         The name of the model to use when publish is True.
+    must_compute : str, optional
+        Indicates the computation strategy. Must be either 'all_or_none' or 'ignore_failures'. Default is 'all_or_none'.
 
     Example:
     --------
@@ -133,7 +135,7 @@ class Detect:
     """
     DEFAULT_CONFIG = {'hallucination': {'detector_name': 'default'}}
 
-    def __init__(self, values_returned, api_key=None, config=None, async_mode=False, publish=False, application_name=None, model_name=None):
+    def __init__(self, values_returned, api_key=None, config=None, async_mode=False, publish=False, application_name=None, model_name=None, must_compute='all_or_none'):
         """
         :param values_returned: A list of values in the order returned by the decorated function
                                 Acceptable values are 'generated_text', 'context', 'user_query', 'instructions'
@@ -144,6 +146,7 @@ def __init__(self, values_returned, api_key=None, config=None, async_mode=False,
         :param publish: Boolean, if True, the payload will be published to AIMon and can be viewed on the AIMon UI. Default is False.
         :param application_name: The name of the application to use when publish is True
         :param model_name: The name of the model to use when publish is True
+        :param must_compute: String, indicates the computation strategy. Must be either 'all_or_none' or 'ignore_failures'. Default is 'all_or_none'.
         """
         api_key = os.getenv('AIMON_API_KEY') if not api_key else api_key
         if api_key is None:
@@ -163,8 +166,15 @@ def __init__(self, values_returned, api_key=None, config=None, async_mode=False,
             if model_name is None:
                 raise ValueError("Model name must be provided if publish is True")
 
+        # Validate must_compute parameter
+        if not isinstance(must_compute, str):
+            raise ValueError("`must_compute` must be a string value")
+        if must_compute not in ['all_or_none', 'ignore_failures']:
+            raise ValueError("`must_compute` must be either 'all_or_none' or 'ignore_failures'")
+        self.must_compute = must_compute
+
         self.application_name = application_name
-        self.model_name = model_name 
+        self.model_name = model_name
 
     def __call__(self, func):
         @wraps(func)
@@ -181,6 +191,7 @@ def wrapper(*args, **kwargs):
             aimon_payload['config'] = self.config
             aimon_payload['publish'] = self.publish
             aimon_payload['async_mode'] = self.async_mode
+            aimon_payload['must_compute'] = self.must_compute
 
             # Include application_name and model_name if publishing
             if self.publish:
diff --git a/aimon/types/inference_detect_params.py b/aimon/types/inference_detect_params.py
index fe98ab0..be05eb5 100644
--- a/aimon/types/inference_detect_params.py
+++ b/aimon/types/inference_detect_params.py
@@ -81,6 +81,9 @@ class Body(TypedDict, total=False):
     model_name: str
     """The model name for publishing metrics for an application."""
 
+    must_compute: Literal["all_or_none", "ignore_failures"]
+    """Indicates the computation strategy. Must be either 'all_or_none' or 'ignore_failures'."""
+
     publish: bool
     """Indicates whether to publish metrics."""
 
diff --git a/tests/test_detect.py b/tests/test_detect.py
index ad0a07f..8fc8dfc 100644
--- a/tests/test_detect.py
+++ b/tests/test_detect.py
@@ -824,3 +824,158 @@ def test_evaluate_with_new_model(self):
             import os
             if os.path.exists(dataset_path):
                 os.remove(dataset_path)
+
+    def test_must_compute_validation(self):
+        """Test that the must_compute parameter is properly validated."""
+        print("\n=== Testing must_compute validation ===")
+
+        # Test config with both hallucination and completeness
+        test_config = {
+            "hallucination": {
+                "detector_name": "default"
+            },
+            "completeness": {
+                "detector_name": "default"
+            }
+        }
+        print(f"Test Config: {test_config}")
+
+        # Test valid values
+        valid_values = ['all_or_none', 'ignore_failures']
+        print(f"Testing valid must_compute values: {valid_values}")
+
+        for value in valid_values:
+            print(f"Testing valid must_compute value: {value}")
+            detect = Detect(
+                values_returned=["context", "generated_text"],
+                api_key=self.api_key,
+                config=test_config,
+                must_compute=value
+            )
+            assert detect.must_compute == value
+            print(f"✅ Successfully validated must_compute value: {value}")
+
+        # Test invalid string value
+        invalid_string_value = "invalid_value"
+        print(f"Testing invalid must_compute string value: {invalid_string_value}")
+        try:
+            Detect(
+                values_returned=["context", "generated_text"],
+                api_key=self.api_key,
+                config=test_config,
+                must_compute=invalid_string_value
+            )
+        except ValueError as e:
+            print(f"✅ Successfully caught ValueError for invalid string: {str(e)}")
+            assert "`must_compute` must be either 'all_or_none' or 'ignore_failures'" in str(e)
+        else:
+            raise AssertionError("Expected ValueError for invalid string value")
+
+        # Test non-string value
+        non_string_value = 123
+        print(f"Testing non-string must_compute value: {non_string_value}")
+        try:
+            Detect(
+                values_returned=["context", "generated_text"],
+                api_key=self.api_key,
+                config=test_config,
+                must_compute=non_string_value
+            )
+        except ValueError as e:
+            print(f"✅ Successfully caught ValueError for non-string: {str(e)}")
+            assert "`must_compute` must be a string value" in str(e)
+        else:
+            raise AssertionError("Expected ValueError for non-string value")
+
+        # Test default value
+        print("Testing default must_compute value: default")
+        detect_default = Detect(
+            values_returned=["context", "generated_text"],
+            api_key=self.api_key,
+            config=test_config
+        )
+        assert detect_default.must_compute == 'all_or_none'
+        print(f"✅ Successfully validated default must_compute value: {detect_default.must_compute}")
+
+        print("🎉 Result: must_compute validation working correctly")
+
+    def test_must_compute_with_actual_service(self):
+        """Test must_compute functionality with actual service calls."""
+        print("\n=== Testing must_compute with actual service ===")
+
+        # Test config with both hallucination and completeness
+        test_config = {
+            "hallucination": {
+                "detector_name": "default"
+            },
+            "completeness": {
+                "detector_name": "default"
+            }
+        }
+        print(f"Test Config: {test_config}")
+
+        # Test both must_compute values
+        for must_compute_value in ['all_or_none', 'ignore_failures']:
+            print(f"\n--- Testing must_compute: {must_compute_value} ---")
+
+            detect = Detect(
+                values_returned=["context", "generated_text", "user_query"],
+                api_key=self.api_key,
+                config=test_config,
+                must_compute=must_compute_value
+            )
+
+            @detect
+            def generate_summary(context, query):
+                generated_text = f"Summary of {context} based on query: {query}"
+                return context, generated_text, query
+
+            # Test data
+            context = "Machine learning is a subset of artificial intelligence that enables computers to learn without being explicitly programmed."
+            query = "What is machine learning?"
+
+            print(f"Input Context: {context}")
+            print(f"Input Query: {query}")
+            print(f"Must Compute: {must_compute_value}")
+
+            try:
+                # Only the service call belongs in the try body, so a bug in the
+                # reporting code below cannot be mistaken for a service error.
+                context_ret, generated_text, query_ret, result = generate_summary(context, query)
+            except Exception as e:
+                error_message = str(e)
+                print(f"API Call Result: {error_message}")
+                print(f"Error Type: {type(e).__name__}")
+
+                # For all_or_none, 503 is expected when services are unavailable
+                if must_compute_value == 'all_or_none' and '503' in error_message:
+                    print("✅ Expected behavior: all_or_none returns 503 when services unavailable")
+                # For ignore_failures, a 503 means the failure was not ignored
+                elif must_compute_value == 'ignore_failures' and '503' not in error_message:
+                    print("✅ Expected behavior: ignore_failures handled the error appropriately")
+                else:
+                    # Fail the test instead of only printing, so regressions are visible
+                    raise AssertionError(
+                        f"Unexpected error for {must_compute_value}: {error_message}"
+                    ) from e
+            else:
+                print("✅ API Call Successful!")
+                print(f"Status Code: {result.status}")
+                print(f"Generated Text: {generated_text}")
+
+                # Display response details
+                if hasattr(result.detect_response, 'hallucination'):
+                    hallucination = result.detect_response.hallucination
+                    print(f"Hallucination Score: {hallucination.get('score', 'N/A')}")
+                    print(f"Is Hallucinated: {hallucination.get('is_hallucinated', 'N/A')}")
+
+                if hasattr(result.detect_response, 'completeness'):
+                    completeness = result.detect_response.completeness
+                    print(f"Completeness Score: {completeness.get('score', 'N/A')}")
+
+                # Show the full response structure
+                print(f"Response Object Type: {type(result.detect_response)}")
+                if hasattr(result.detect_response, '__dict__'):
+                    print(f"Response Attributes: {list(result.detect_response.__dict__.keys())}")
+
+        print("\n🎉 All must_compute service tests completed!")