From 69ab0f94ccb84cdf10f8deff0de97211a6796cc0 Mon Sep 17 00:00:00 2001 From: Konie Date: Thu, 25 Jan 2024 14:48:07 +0800 Subject: [PATCH] Optimized async generation response when worker queue is full --- fooocusapi/api.py | 102 ++++++++++++++++++++++++++-------------- fooocusapi/api_utils.py | 90 +++++++++++++++++------------------ fooocusapi/models.py | 2 +- 3 files changed, 113 insertions(+), 81 deletions(-) diff --git a/fooocusapi/api.py b/fooocusapi/api.py index 063ebb5..88dc47a 100644 --- a/fooocusapi/api.py +++ b/fooocusapi/api.py @@ -9,10 +9,10 @@ from fooocusapi.args import args from fooocusapi.models import * -from fooocusapi.api_utils import generation_output, req_to_params +from fooocusapi.api_utils import req_to_params, generate_async_output, generate_streaming_output, generate_image_result_output import fooocusapi.file_utils as file_utils from fooocusapi.parameters import GenerationFinishReason, ImageGenerationResult -from fooocusapi.task_queue import TaskType +from fooocusapi.task_queue import QueueTask, TaskType from fooocusapi.worker import process_generate, task_queue, process_top from fooocusapi.models_v2 import * from fooocusapi.img_utils import base64_to_stream, read_input_image @@ -58,15 +58,19 @@ } -def call_worker(req: Text2ImgRequest, accept: str): - task_type = TaskType.text_2_img +def get_task_type(req: Text2ImgRequest) -> TaskType: if isinstance(req, ImgUpscaleOrVaryRequest) or isinstance(req, ImgUpscaleOrVaryRequestJson): - task_type = TaskType.img_uov + return TaskType.img_uov elif isinstance(req, ImgPromptRequest) or isinstance(req, ImgPromptRequestJson): - task_type = TaskType.img_prompt + return TaskType.img_prompt elif isinstance(req, ImgInpaintOrOutpaintRequest) or isinstance(req, ImgInpaintOrOutpaintRequestJson): - task_type = TaskType.img_inpaint_outpaint + return TaskType.img_inpaint_outpaint + else: + return TaskType.text_2_img + +def call_worker(req: Text2ImgRequest, accept: str) -> Tuple[QueueTask | None, List[ImageGenerationResult] | None]: + task_type = get_task_type(req) params = req_to_params(req) queue_task = task_queue.add_task( task_type, {'params': params.__dict__, 'accept': accept, 'require_base64': req.require_base64}, @@ -74,15 +78,43 @@ def call_worker(req: Text2ImgRequest, accept: str): if queue_task is None: print("[Task Queue] The task queue has reached limit") - results = [ImageGenerationResult(im=None, seed=0, + return None, [ImageGenerationResult(im=None, seed='', finish_reason=GenerationFinishReason.queue_is_full)] elif req.async_process: work_executor.submit(process_generate, queue_task, params) - results = queue_task + return queue_task, None else: results = process_generate(queue_task, params) - - return results + return queue_task, results + + +def build_generation_response(req: Text2ImgRequest, + streaming_output: bool, + task: QueueTask | None, + results: List[ImageGenerationResult] | None) -> Response | AsyncJobResponse | List[GeneratedImageResult]: + if streaming_output: + return generate_streaming_output([] if results is None else results) + + job_result: List[GeneratedImageResult] = [] + if results is not None: + job_result = generate_image_result_output(results, req.require_base64) + + if task is None: + # add to worker queue failed + if req.async_process: + return AsyncJobResponse(job_id='', + job_type=get_task_type(req), + job_stage=AsyncJobStage.error, + job_progress=0, + job_status=None, + job_step_preview=None, + job_result=job_result) + return job_result + + if req.async_process: + return generate_async_output(task) + else: + return job_result def stop_worker(): @@ -112,8 +144,8 @@ def text2img_generation(req: Text2ImgRequest, accept: str = Header(None), else: streaming_output = False - results = call_worker(req, accept) - return generation_output(results, streaming_output, req.require_base64) + task, results = call_worker(req, accept) + return build_generation_response(req, streaming_output, task, results) @app.post("/v2/generation/text-to-image-with-ip", response_model=List[GeneratedImageResult] | AsyncJobResponse, responses=img_generate_responses) @@ -145,8 +177,8 @@ def text_to_img_with_ip(req: Text2ImgRequestWithPrompt, req.image_prompts = image_prompts_files - results = call_worker(req, accept) - return generation_output(results, streaming_output, req.require_base64) + task, results = call_worker(req, accept) + return build_generation_response(req, streaming_output, task, results) @app.post("/v1/generation/image-upscale-vary", response_model=List[GeneratedImageResult] | AsyncJobResponse, responses=img_generate_responses) @@ -163,8 +195,8 @@ def img_upscale_or_vary(input_image: UploadFile, req: ImgUpscaleOrVaryRequest = else: streaming_output = False - results = call_worker(req, accept) - return generation_output(results, streaming_output, req.require_base64) + task, results = call_worker(req, accept) + return build_generation_response(req, streaming_output, task, results) @app.post("/v2/generation/image-upscale-vary", response_model=List[GeneratedImageResult] | AsyncJobResponse, responses=img_generate_responses) @@ -195,8 +227,8 @@ def img_upscale_or_vary_v2(req: ImgUpscaleOrVaryRequestJson, image_prompts_files.append(default_image_promt) req.image_prompts = image_prompts_files - results = call_worker(req, accept) - return generation_output(results, streaming_output, req.require_base64) + task, results = call_worker(req, accept) + return build_generation_response(req, streaming_output, task, results) @app.post("/v1/generation/image-inpait-outpaint", response_model=List[GeneratedImageResult] | AsyncJobResponse, responses=img_generate_responses) @@ -213,8 +245,8 @@ def img_inpaint_or_outpaint(input_image: UploadFile, req: ImgInpaintOrOutpaintRe else: streaming_output = False - results = call_worker(req, accept) - return generation_output(results, streaming_output, req.require_base64) + task, results = call_worker(req, accept) + return build_generation_response(req, streaming_output, task, results) @app.post("/v2/generation/image-inpait-outpaint", response_model=List[GeneratedImageResult] | AsyncJobResponse, responses=img_generate_responses) @@ -247,8 +279,8 @@ def img_inpaint_or_outpaint_v2(req: ImgInpaintOrOutpaintRequestJson, image_prompts_files.append(default_image_promt) req.image_prompts = image_prompts_files - results = call_worker(req, accept) - return generation_output(results, streaming_output, req.require_base64) + task, results = call_worker(req, accept) + return build_generation_response(req, streaming_output, task, results) @app.post("/v1/generation/image-prompt", response_model=List[GeneratedImageResult] | AsyncJobResponse, responses=img_generate_responses) @@ -266,8 +298,8 @@ def img_prompt(cn_img1: Optional[UploadFile] = File(None), else: streaming_output = False - results = call_worker(req, accept) - return generation_output(results, streaming_output, req.require_base64) + task, results = call_worker(req, accept) + return build_generation_response(req, streaming_output, task, results) @app.post("/v2/generation/image-prompt", response_model=List[GeneratedImageResult] | AsyncJobResponse, responses=img_generate_responses) @@ -304,22 +336,22 @@ def img_prompt(req: ImgPromptRequestJson, req.image_prompts = image_prompts_files - results = call_worker(req, accept) - return generation_output(results, streaming_output, req.require_base64) + task, results = call_worker(req, accept) + return build_generation_response(req, streaming_output, task, results) @app.get("/v1/generation/query-job", response_model=AsyncJobResponse, description="Query async generation job") def query_job(req: QueryJobRequest = Depends()): queue_task = task_queue.get_task(req.job_id, True) if queue_task is None: - return JSONResponse(content=AsyncJobResponse(job_id="", - job_type="Not Found", - job_stage="ERROR", - job_progress=0, - job_status="Job not found"), status_code=404) - - return generation_output(queue_task, streaming_output=False, require_base64=False, - require_step_preivew=req.require_step_preivew) + result = AsyncJobResponse(job_id="", + job_type=TaskType.not_found, + job_stage=AsyncJobStage.error, + job_progress=0, + job_status="Job not found") + content = result.model_dump_json() + return Response(content=content, media_type='application/json', status_code=404) + return generate_async_output(queue_task) @app.get("/v1/generation/job-queue", response_model=JobQueueInfo, description="Query job queue info") diff --git a/fooocusapi/api_utils.py b/fooocusapi/api_utils.py index f3cf89c..1ae6888 100644 --- a/fooocusapi/api_utils.py +++ b/fooocusapi/api_utils.py @@ -159,54 +159,54 @@ def req_to_params(req: Text2ImgRequest) -> ImageGenerationParams: ) -def generation_output(results: QueueTask | List[ImageGenerationResult], streaming_output: bool, require_base64: bool, require_step_preivew: bool=False) -> Response | List[GeneratedImageResult] | AsyncJobResponse: - if isinstance(results, QueueTask): - task = results - job_stage = AsyncJobStage.running - job_result = None - if task.start_millis == 0: - job_stage = AsyncJobStage.waiting - if task.is_finished: - if task.finish_with_error: - job_stage = AsyncJobStage.error - else: - if task.task_result != None: - job_stage = AsyncJobStage.success - task_result_require_base64 = False - if 'require_base64' in task.req_param and task.req_param['require_base64']: - task_result_require_base64 = True - - job_result = generation_output(task.task_result, False, task_result_require_base64) - job_step_preview = None if not require_step_preivew else task.task_step_preview - return AsyncJobResponse(job_id=task.job_id, - job_type=task.type, - job_stage=job_stage, - job_progress=task.finish_progress, - job_status=task.task_status, - job_step_preview=job_step_preview, - job_result=job_result) - - if streaming_output: - if len(results) == 0: - return Response(status_code=500) - result = results[0] - if result.finish_reason == GenerationFinishReason.queue_is_full: - return Response(status_code=409, content=result.finish_reason.value) - elif result.finish_reason == GenerationFinishReason.user_cancel: - return Response(status_code=400, content=result.finish_reason.value) - elif result.finish_reason == GenerationFinishReason.error: - return Response(status_code=500, content=result.finish_reason.value) - - bytes = output_file_to_bytesimg(results[0].im) - return Response(bytes, media_type='image/png') - else: - results = [GeneratedImageResult( - base64=output_file_to_base64img( - item.im) if require_base64 else None, +def generate_async_output(task: QueueTask) -> AsyncJobResponse: + job_stage = AsyncJobStage.running + job_result = None + + if task.start_millis == 0: + job_stage = AsyncJobStage.waiting + + if task.is_finished: + if task.finish_with_error: + job_stage = AsyncJobStage.error + elif task.task_result != None: + job_stage = AsyncJobStage.success + task_result_require_base64 = False + if 'require_base64' in task.req_param and task.req_param['require_base64']: + task_result_require_base64 = True + + job_result = generate_image_result_output(task.task_result, task_result_require_base64) + return AsyncJobResponse(job_id=task.job_id, + job_type=task.type, + job_stage=job_stage, + job_progress=task.finish_progress, + job_status=task.task_status, + job_step_preview=task.task_step_preview, + job_result=job_result) + + +def generate_streaming_output(results: List[ImageGenerationResult]) -> Response: + if len(results) == 0: + return Response(status_code=500) + result = results[0] + if result.finish_reason == GenerationFinishReason.queue_is_full: + return Response(status_code=409, content=result.finish_reason.value) + elif result.finish_reason == GenerationFinishReason.user_cancel: + return Response(status_code=400, content=result.finish_reason.value) + elif result.finish_reason == GenerationFinishReason.error: + return Response(status_code=500, content=result.finish_reason.value) + + bytes = output_file_to_bytesimg(results[0].im) + return Response(bytes, media_type='image/png') + + +def generate_image_result_output(results: List[ImageGenerationResult], require_base64: bool) -> List[GeneratedImageResult]: + results = [GeneratedImageResult( + base64=output_file_to_base64img(item.im) if require_base64 else None, url=get_file_serve_url(item.im), seed=item.seed, finish_reason=item.finish_reason) for item in results] - return results + return results class QueueReachLimitException(Exception): diff --git a/fooocusapi/models.py b/fooocusapi/models.py index 728b87c..b88578d 100644 --- a/fooocusapi/models.py +++ b/fooocusapi/models.py @@ -421,7 +421,7 @@ class AsyncJobResponse(BaseModel): class JobQueueInfo(BaseModel): running_size: int = Field(description="The current running and waiting job count") finished_size: int = Field(description="Finished job cound (after auto clean)") - last_job_id: str = Field(description="Last submit generation job id") + last_job_id: str | None = Field(description="Last submit generation job id") # TODO May need more detail fields, will add later when someone need