diff --git a/core_backend/app/config.py b/core_backend/app/config.py index f762c2ae2..501417626 100644 --- a/core_backend/app/config.py +++ b/core_backend/app/config.py @@ -33,7 +33,8 @@ LITELLM_MODEL_GENERATION = os.environ.get( "LITELLM_MODEL_GENERATION", "openai/generate-gemini-response", - # "LITELLM_MODEL_GENERATION", "openai/generate-response" + # "LITELLM_MODEL_GENERATION", + # "openai/generate-response", ) LITELLM_MODEL_LANGUAGE_DETECT = os.environ.get( "LITELLM_MODEL_LANGUAGE_DETECT", "openai/detect-language" @@ -64,6 +65,7 @@ ALIGN_SCORE_METHOD = os.environ.get("ALIGN_SCORE_METHOD", "LLM") # if AlignScore, set ALIGN_SCORE_API. If LLM, set LITELLM_MODEL_ALIGNSCORE above. ALIGN_SCORE_API = os.environ.get("ALIGN_SCORE_API", "") +ALIGN_SCORE_N_RETRIES = os.environ.get("ALIGN_SCORE_N_RETRIES", 1) # Backend paths BACKEND_ROOT_PATH = os.environ.get("BACKEND_ROOT_PATH", "") diff --git a/core_backend/app/llm_call/process_output.py b/core_backend/app/llm_call/process_output.py index 4a2eaf650..3f550e2f9 100644 --- a/core_backend/app/llm_call/process_output.py +++ b/core_backend/app/llm_call/process_output.py @@ -171,7 +171,7 @@ async def wrapper( response = await func(query_refined, response, *args, **kwargs) - if not kwargs.get("generate_llm_response", False): + if not query_refined.generate_llm_response: return response metadata = create_langfuse_metadata( diff --git a/core_backend/app/question_answer/routers.py b/core_backend/app/question_answer/routers.py index 02346c72e..7a72abbd0 100644 --- a/core_backend/app/question_answer/routers.py +++ b/core_backend/app/question_answer/routers.py @@ -5,13 +5,14 @@ import os from typing import Tuple +import backoff from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status from fastapi.responses import JSONResponse from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from ..auth.dependencies import authenticate_key, rate_limiter -from ..config import SPEECH_ENDPOINT +from ..config import ALIGN_SCORE_N_RETRIES, SPEECH_ENDPOINT from ..contents.models import ( get_similar_content_async, increment_query_count, @@ -42,6 +43,7 @@ ) from .schemas import ( ContentFeedback, + ErrorType, QueryBase, QueryRefined, QueryResponse, @@ -216,6 +218,17 @@ async def search( contents=response.search_results, asession=asession, ) + if is_unable_to_generate_response(response): + failure_reason = response.debug_info["factual_consistency"] + response = await retry_search( + query_refined=user_query_refined_template, + response=response_template, + user_id=user_db.user_id, + n_similar=int(N_TOP_CONTENT), + asession=asession, + exclude_archived=True, + ) + response.debug_info["past_failure"] = failure_reason if type(response) is QueryResponse: return response @@ -234,8 +247,8 @@ async def search( @classify_safety__before @translate_question__before @paraphrase_question__before -@generate_llm_response__after @check_align_score__after +@generate_llm_response__after async def search_base( query_refined: QueryRefined, response: QueryResponse, @@ -288,6 +301,39 @@ async def search_base( return response +def is_unable_to_generate_response(response: QueryResponse) -> bool: + """ + Check if the response is of type QueryResponseError and caused + by low alignment score. + """ + return ( + isinstance(response, QueryResponseError) + and response.error_type == ErrorType.ALIGNMENT_TOO_LOW + ) + + +@backoff.on_predicate( + backoff.expo, + max_tries=int(ALIGN_SCORE_N_RETRIES), + predicate=is_unable_to_generate_response, +) +async def retry_search( + query_refined: QueryRefined, + response: QueryResponse | QueryResponseError, + user_id: int, + n_similar: int, + asession: AsyncSession, + exclude_archived: bool = True, +) -> QueryResponse | QueryResponseError: + """ + Retry wrapper for search_base. + """ + + return await search_base( + query_refined, response, user_id, n_similar, asession, exclude_archived + ) + + async def get_user_query_and_response( user_id: int, user_query: QueryBase, asession: AsyncSession ) -> Tuple[QueryDB, QueryRefined, QueryResponse]: diff --git a/core_backend/requirements.txt b/core_backend/requirements.txt index 6e6b15134..667d5ab9d 100644 --- a/core_backend/requirements.txt +++ b/core_backend/requirements.txt @@ -19,3 +19,4 @@ types-openpyxl==3.1.4.20240621 redis==5.0.8 python-dateutil==2.8.2 gTTS==2.5.1 +backoff==2.2.1