Skip to content

Commit 2b8f7a3

Browse files
authored
Remove extra validation checks and add some tests for detect and evaluate (#59)
* Remove extra validation checks * Add tests for detect and evaluate * Update run.sh for running tests * Add tests for Analyze functions * Remove AnalyzeProd, AnalyzeEval * Bump version * Add low-level API tests * Update the version
1 parent 1c89174 commit 2b8f7a3

File tree

12 files changed

+1774
-367
lines changed

12 files changed

+1774
-367
lines changed

aimon/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,4 +82,4 @@
8282
pass
8383

8484
from .decorators.detect import Detect
85-
from .decorators.evaluate import AnalyzeEval, AnalyzeProd, Application, Model, evaluate, EvaluateResponse
85+
from .decorators.evaluate import Application, Model, evaluate, EvaluateResponse

aimon/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
22

33
__title__ = "aimon"
4-
__version__ = "0.9.2"
4+
__version__ = "0.10.0"

aimon/decorators/detect.py

Lines changed: 3 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -151,10 +151,8 @@ def __init__(self, values_returned, api_key=None, config=None, async_mode=False,
151151
self.client = Client(auth_header="Bearer {}".format(api_key))
152152
self.config = config if config else self.DEFAULT_CONFIG
153153
self.values_returned = values_returned
154-
if self.values_returned is None or len(self.values_returned) == 0:
155-
raise ValueError("values_returned by the decorated function must be specified")
156-
if "context" not in self.values_returned:
157-
raise ValueError("values_returned must contain 'context'")
154+
if self.values_returned is None or not hasattr(self.values_returned, '__iter__') or len(self.values_returned) == 0:
155+
raise ValueError("values_returned must be specified and be an iterable")
158156
self.async_mode = async_mode
159157
self.publish = publish
160158
if self.async_mode:
@@ -178,29 +176,7 @@ def wrapper(*args, **kwargs):
178176
result = (result,)
179177

180178
# Create a dictionary mapping output names to results
181-
result_dict = {name: value for name, value in zip(self.values_returned, result)}
182-
183-
aimon_payload = {}
184-
if 'generated_text' in result_dict:
185-
aimon_payload['generated_text'] = result_dict['generated_text']
186-
else:
187-
raise ValueError("Result of the wrapped function must contain 'generated_text'")
188-
if 'context' in result_dict:
189-
aimon_payload['context'] = result_dict['context']
190-
else:
191-
raise ValueError("Result of the wrapped function must contain 'context'")
192-
if 'user_query' in result_dict:
193-
aimon_payload['user_query'] = result_dict['user_query']
194-
if 'instructions' in result_dict:
195-
aimon_payload['instructions'] = result_dict['instructions']
196-
197-
if 'retrieval_relevance' in self.config:
198-
if 'task_definition' in result_dict:
199-
aimon_payload['task_definition'] = result_dict['task_definition']
200-
else:
201-
raise ValueError( "When retrieval_relevance is specified in the config, "
202-
"'task_definition' must be present in the result of the wrapped function.")
203-
179+
aimon_payload = {name: value for name, value in zip(self.values_returned, result)}
204180

205181
aimon_payload['config'] = self.config
206182
aimon_payload['publish'] = self.publish

aimon/decorators/evaluate.py

Lines changed: 5 additions & 272 deletions
Original file line numberDiff line numberDiff line change
@@ -225,8 +225,6 @@ def evaluate(
225225
# Validate headers to be non-empty and contain at least the context_docs column
226226
if not headers:
227227
raise ValueError("Headers must be a non-empty list")
228-
if "context_docs" not in headers:
229-
raise ValueError("Headers must contain the column 'context_docs'")
230228

231229
# Create application and models
232230
am_app = client.applications.create(
@@ -276,287 +274,22 @@ def evaluate(
276274
if ag not in record:
277275
raise ValueError("Dataset record must contain the column '{}' as specified in the 'headers'"
278276
" argument in the decorator".format(ag))
279-
280-
if "context_docs" not in record:
281-
raise ValueError("Dataset record must contain the column 'context_docs'")
282277

283-
_context = record['context_docs'] if isinstance(record['context_docs'], list) else [record['context_docs']]
284278
# Construct the payload for the analysis
285279
payload = {
280+
**record,
281+
"config": config,
286282
"application_id": am_app.id,
287283
"version": am_app.version,
288-
"context_docs": [d for d in _context],
289284
"evaluation_id": am_eval.id,
290285
"evaluation_run_id": eval_run.id,
291286
}
292-
if "prompt" in record and record["prompt"]:
293-
payload["prompt"] = record["prompt"]
294-
if "user_query" in record and record["user_query"]:
295-
payload["user_query"] = record["user_query"]
296-
if "output" in record and record["output"]:
297-
payload["output"] = record["output"]
298-
if "instruction_adherence" in config and "instructions" not in record:
299-
raise ValueError("When instruction_adherence is specified in the config, "
300-
"'instructions' must be present in the dataset")
301-
if "instructions" in record and "instruction_adherence" in config:
302-
# Only pass instructions if instruction_adherence is specified in the config
303-
payload["instructions"] = record["instructions"] or ""
304-
305-
if "retrieval_relevance" in config:
306-
if "task_definition" in record:
307-
payload["task_definition"] = record["task_definition"]
308-
else:
309-
raise ValueError( "When retrieval_relevance is specified in the config, "
310-
"'task_definition' must be present in the dataset")
311-
312-
payload["config"] = config
287+
if "instructions" in payload and not payload["instructions"]:
288+
payload["instructions"] = ""
289+
313290
results.append(EvaluateResponse(record['output'], client.analyze.create(body=[payload])))
314291

315292
return results
316293

317-
class AnalyzeBase:
318-
DEFAULT_CONFIG = {'hallucination': {'detector_name': 'default'}}
319-
320-
def __init__(self, application, model, api_key=None, config=None):
321-
"""
322-
:param application: An Application object
323-
:param model: A Model object
324-
:param api_key: The API key to use for the Aimon client
325-
"""
326-
self.client = Client(auth_header="Bearer {}".format(api_key))
327-
self.application = application
328-
self.model = model
329-
self.config = config if config else self.DEFAULT_CONFIG
330-
self.initialize()
331-
332-
def initialize(self):
333-
# Create or retrieve the model
334-
self._am_model = self.client.models.create(
335-
name=self.model.name,
336-
type=self.model.model_type,
337-
description="This model is named {} and is of type {}".format(self.model.name, self.model.model_type),
338-
metadata=self.model.metadata
339-
)
340-
341-
# Create or retrieve the application
342-
self._am_app = self.client.applications.create(
343-
name=self.application.name,
344-
model_name=self._am_model.name,
345-
stage=self.application.stage,
346-
type=self.application.type,
347-
metadata=self.application.metadata
348-
)
349-
350-
351-
class AnalyzeEval(AnalyzeBase):
352-
353-
def __init__(self, application, model, evaluation_name, dataset_collection_name, headers,
354-
api_key=None, eval_tags=None, config=None):
355-
"""
356-
The wrapped function should have a signature as follows:
357-
def func(context_docs, user_query, prompt, instructions *args, **kwargs):
358-
# Your code here
359-
return output
360-
[Required] The first argument must be a 'context_docs' which is of type List[str].
361-
[Required] The second argument must be a 'user_query' which is of type str.
362-
[Optional] The third argument must be a 'prompt' which is of type str
363-
[Optional] If an 'instructions' column is present in the dataset, then the fourth argument
364-
must be 'instructions' which is of type str
365-
[Optional] If an 'output' column is present in the dataset, then the fifth argument
366-
must be 'output' which is of type str
367-
Return: The function must return an output which is of type str
368-
369-
:param application: An Application object
370-
:param model: A Model object
371-
:param evaluation_name: The name of the evaluation
372-
:param dataset_collection_name: The name of the dataset collection
373-
:param headers: A list containing the headers to be used for the evaluation
374-
:param api_key: The API key to use for the AIMon client
375-
:param eval_tags: A list of tags to associate with the evaluation
376-
:param config: A dictionary containing the AIMon configuration for the evaluation
377-
378-
379-
"""
380-
super().__init__(application, model, api_key, config)
381-
warnings.warn(
382-
f"{self.__class__.__name__} is deprecated and will be removed in a later release. Please use the evaluate method instead.",
383-
DeprecationWarning,
384-
stacklevel=2
385-
)
386-
self.headers = headers
387-
self.evaluation_name = evaluation_name
388-
self.dataset_collection_name = dataset_collection_name
389-
self.eval_tags = eval_tags
390-
self.eval_initialize()
391-
392-
def eval_initialize(self):
393-
if self.dataset_collection_name is None:
394-
raise ValueError("Dataset collection name must be provided for running an evaluation.")
395-
396-
# Create or retrieve the dataset collection
397-
self._am_dataset_collection = self.client.datasets.collection.retrieve(name=self.dataset_collection_name)
398-
399-
# Create or retrieve the evaluation
400-
self._eval = self.client.evaluations.create(
401-
name=self.evaluation_name,
402-
application_id=self._am_app.id,
403-
model_id=self._am_model.id,
404-
dataset_collection_id=self._am_dataset_collection.id
405-
)
406-
407-
def _run_eval(self, func, args, kwargs):
408-
# Create an evaluation run
409-
eval_run = self.client.evaluations.run.create(
410-
evaluation_id=self._eval.id,
411-
metrics_config=self.config,
412-
)
413-
# Get all records from the datasets
414-
dataset_collection_records = []
415-
for dataset_id in self._am_dataset_collection.dataset_ids:
416-
dataset_records = self.client.datasets.records.list(sha=dataset_id)
417-
dataset_collection_records.extend(dataset_records)
418-
results = []
419-
for record in dataset_collection_records:
420-
# The record must contain the context_docs and user_query fields.
421-
# The prompt, output and instructions fields are optional.
422-
# Inspect the record and call the function with the appropriate arguments
423-
arguments = []
424-
for ag in self.headers:
425-
if ag not in record:
426-
raise ValueError("Record must contain the column '{}' as specified in the 'headers'"
427-
" argument in the decorator".format(ag))
428-
arguments.append(record[ag])
429-
# Inspect the function signature to ensure that it accepts the correct arguments
430-
sig = inspect.signature(func)
431-
params = sig.parameters
432-
if len(params) < len(arguments):
433-
raise ValueError("Function must accept at least {} arguments".format(len(arguments)))
434-
# Ensure that the first len(arguments) parameters are named correctly
435-
param_names = list(params.keys())
436-
if param_names[:len(arguments)] != self.headers:
437-
raise ValueError("Function arguments must be named as specified by the 'headers' argument: {}".format(
438-
self.headers))
439-
440-
result = func(*arguments, *args, **kwargs)
441-
_context = record['context_docs'] if isinstance(record['context_docs'], list) else [record['context_docs']]
442-
payload = {
443-
"application_id": self._am_app.id,
444-
"version": self._am_app.version,
445-
"prompt": record['prompt'] or "",
446-
"user_query": record['user_query'] or "",
447-
"context_docs": [d for d in _context],
448-
"output": result,
449-
"evaluation_id": self._eval.id,
450-
"evaluation_run_id": eval_run.id,
451-
}
452-
if "instruction_adherence" in self.config and "instructions" not in record:
453-
raise ValueError("When instruction_adherence is specified in the config, "
454-
"'instructions' must be present in the dataset")
455-
if "instructions" in record and "instruction_adherence" in self.config:
456-
# Only pass instructions if instruction_adherence is specified in the config
457-
payload["instructions"] = record["instructions"] or ""
458-
459-
if "retrieval_relevance" in self.config:
460-
if "task_definition" in record:
461-
payload["task_definition"] = record["task_definition"]
462-
else:
463-
raise ValueError( "When retrieval_relevance is specified in the config, "
464-
"'task_definition' must be present in the dataset")
465-
466-
payload["config"] = self.config
467-
results.append((result, self.client.analyze.create(body=[payload])))
468-
return results
469-
470-
def __call__(self, func):
471-
@wraps(func)
472-
def wrapper(*args, **kwargs):
473-
return self._run_eval(func, args, kwargs)
474-
475-
return wrapper
476-
477-
478-
class AnalyzeProd(AnalyzeBase):
479-
480-
def __init__(self, application, model, values_returned, api_key=None, config=None):
481-
"""
482-
The wrapped function should return a tuple of values in the order specified by values_returned. In addition,
483-
the wrapped function should accept a parameter named eval_obj which will be used when using this decorator
484-
in evaluation mode.
485-
486-
:param application: An Application object
487-
:param model: A Model object
488-
:param values_returned: A list of values in the order returned by the decorated function
489-
Acceptable values are 'generated_text', 'context', 'user_query', 'instructions'
490-
"""
491-
application.stage = "production"
492-
super().__init__(application, model, api_key, config)
493-
warnings.warn(
494-
f"{self.__class__.__name__} is deprecated and will be removed in a later release. Please use Detect with async=True instead.",
495-
DeprecationWarning,
496-
stacklevel=2
497-
)
498-
self.values_returned = values_returned
499-
if self.values_returned is None or len(self.values_returned) == 0:
500-
raise ValueError("Values returned by the decorated function must be specified")
501-
if "generated_text" not in self.values_returned:
502-
raise ValueError("values_returned must contain 'generated_text'")
503-
if "context" not in self.values_returned:
504-
raise ValueError("values_returned must contain 'context'")
505-
if "instruction_adherence" in self.config and "instructions" not in self.values_returned:
506-
raise ValueError(
507-
"When instruction_adherence is specified in the config, 'instructions' must be returned by the decorated function")
508-
509-
if "retrieval_relevance" in self.config and "task_definition" not in self.values_returned:
510-
raise ValueError( "When retrieval_relevance is specified in the config, "
511-
"'task_definition' must be returned by the decorated function")
512-
513-
if "instructions" in self.values_returned and "instruction_adherence" not in self.config:
514-
raise ValueError(
515-
"instruction_adherence must be specified in the config for returning 'instructions' by the decorated function")
516-
self.config = config if config else self.DEFAULT_CONFIG
517-
518-
def _run_production_analysis(self, func, args, kwargs):
519-
result = func(*args, **kwargs)
520-
if result is None:
521-
raise ValueError("Result must be returned by the decorated function")
522-
# Handle the case where the result is a single value
523-
if not isinstance(result, tuple):
524-
result = (result,)
525-
526-
# Create a dictionary mapping output names to results
527-
result_dict = {name: value for name, value in zip(self.values_returned, result)}
528-
529-
if "generated_text" not in result_dict:
530-
raise ValueError("Result of the wrapped function must contain 'generated_text'")
531-
if "context" not in result_dict:
532-
raise ValueError("Result of the wrapped function must contain 'context'")
533-
_context = result_dict['context'] if isinstance(result_dict['context'], list) else [result_dict['context']]
534-
aimon_payload = {
535-
"application_id": self._am_app.id,
536-
"version": self._am_app.version,
537-
"output": result_dict['generated_text'],
538-
"context_docs": _context,
539-
"user_query": result_dict["user_query"] if 'user_query' in result_dict else "No User Query Specified",
540-
"prompt": result_dict['prompt'] if 'prompt' in result_dict else "No Prompt Specified",
541-
}
542-
if 'instructions' in result_dict:
543-
aimon_payload['instructions'] = result_dict['instructions']
544-
if 'actual_request_timestamp' in result_dict:
545-
aimon_payload["actual_request_timestamp"] = result_dict['actual_request_timestamp']
546-
if 'task_definition' in result_dict:
547-
aimon_payload['task_definition'] = result_dict['task_definition']
548-
549-
aimon_payload['config'] = self.config
550-
aimon_response = self.client.analyze.create(body=[aimon_payload])
551-
return result + (aimon_response,)
552-
553-
def __call__(self, func):
554-
@wraps(func)
555-
def wrapper(*args, **kwargs):
556-
# Production mode, run the provided args through the user function
557-
return self._run_production_analysis(func, args, kwargs)
558-
559-
return wrapper
560-
561294

562295

0 commit comments

Comments
 (0)