diff --git a/requirements.txt b/requirements.txt index 9e031af8..4e2b1b38 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,7 @@ accelerate addict albumentations basicsr +clip-retrieval controlnet-aux diffusers einops diff --git a/search.py b/search.py new file mode 100644 index 00000000..aebfaf5f --- /dev/null +++ b/search.py @@ -0,0 +1,61 @@ +import requests + +from clip_retrieval.clip_client import ClipClient, Modality +from PIL import Image + +from utils import get_image_name, get_new_image_name, prompts + + +def download_image(img_url, img_path): + img_stream = requests.get(img_url, stream=True) + if img_stream.status_code == 200: + img = Image.open(img_stream.raw) + img.save(img_path, format="png") + return img_path + + +def download_best_available(search_result, result_img_path): + if search_result: + img_path = download_image(search_result[0]["url"], result_img_path) + return img_path if img_path else download_best_available(search_result[1:], result_img_path) + + +class SearchSupport: + def __init__(self): + self.client = ClipClient( + url="https://knn.laion.ai/knn-service", + indice_name="laion5B-L-14", + modality=Modality.IMAGE, + aesthetic_score=0, + aesthetic_weight=0.0, + num_images=10, + ) + + +class ImageSearch(SearchSupport): + def __init__(self, *args, **kwargs): + print("Initializing Image Search") + super().__init__() + + @prompts(name="Search Image That Matches User Input Text", + description="useful when you want to search an image that matches a given description. " + "like: find an image that contains certain objects with certain properties, " + "or refine a previous search with additional criteria. " + "The input to this tool should be a string, representing the description. ") + def inference(self, query_text): + search_result = self.client.query(text=query_text) + return download_best_available(search_result, get_image_name()) + + +class VisualSearch(SearchSupport): + def __init__(self, *args, **kwargs): + print("Initializing Visual Search") + super().__init__() + + @prompts(name="Search Image Visually Similar to an Input Image", + description="useful when you want to search an image that is visually similar to an input image. " + "like: find an image visually similar to a generated or modified image. " + "The input to this tool should be a string, representing the input image path. ") + def inference(self, query_img_path): + search_result = self.client.query(image=query_img_path) + return download_best_available(search_result, get_new_image_name(query_img_path, "visual-search")) diff --git a/utils.py b/utils.py new file mode 100644 index 00000000..fe9a7945 --- /dev/null +++ b/utils.py @@ -0,0 +1,47 @@ +import os +import uuid + + +def prompts(name, description): + def decorator(func): + func.name = name + func.description = description + return func + + return decorator + + +def cut_dialogue_history(history_memory, keep_last_n_words=500): + if history_memory is None or len(history_memory) == 0: + return history_memory + tokens = history_memory.split() + n_tokens = len(tokens) + print(f"history_memory:{history_memory}, n_tokens: {n_tokens}") + if n_tokens < keep_last_n_words: + return history_memory + paragraphs = history_memory.split('\n') + last_n_tokens = n_tokens + while last_n_tokens >= keep_last_n_words: + last_n_tokens -= len(paragraphs[0].split(' ')) + paragraphs = paragraphs[1:] + return '\n' + '\n'.join(paragraphs) + + +def get_new_image_name(org_img_name, func_name="update"): + head_tail = os.path.split(org_img_name) + head = head_tail[0] + tail = head_tail[1] + name_split = tail.split('.')[0].split('_') + this_new_uuid = str(uuid.uuid4())[:4] + if len(name_split) == 1: + most_org_file_name = name_split[0] + else: + assert len(name_split) == 4 + most_org_file_name = name_split[3] + recent_prev_file_name = name_split[0] + new_file_name = f'{this_new_uuid}_{func_name}_{recent_prev_file_name}_{most_org_file_name}.png' + return os.path.join(head, new_file_name) + + +def get_image_name(): + return os.path.join('image', f"{str(uuid.uuid4())[:8]}.png") diff --git a/visual_chatgpt.py b/visual_chatgpt.py index d98dd7cc..db783bf6 100644 --- a/visual_chatgpt.py +++ b/visual_chatgpt.py @@ -4,8 +4,7 @@ import torch import cv2 import re -import uuid -from PIL import Image, ImageDraw, ImageOps +from PIL import Image, ImageOps import math import numpy as np import argparse @@ -25,6 +24,9 @@ from langchain.chains.conversation.memory import ConversationBufferMemory from langchain.llms.openai import OpenAI +from search import ImageSearch, VisualSearch +from utils import cut_dialogue_history, get_image_name, get_new_image_name, prompts + VISUAL_CHATGPT_PREFIX = """Visual ChatGPT is designed to be able to assist with a wide range of text and visual related tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. Visual ChatGPT is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand. Visual ChatGPT is able to process and understand large amounts of text and images. As a language model, Visual ChatGPT can not directly read images, but it has a list of tools to finish different visual tasks. Each image will have a file name formed as "image/xxx.png", and Visual ChatGPT can invoke different tools to indirectly understand pictures. When talking about images, Visual ChatGPT is very strict to the file name and will never fabricate nonexistent files. When using tools to generate new image files, Visual ChatGPT is also known that the image may not be the same as the user's demand, and will use other visual question answering tools or description tools to observe the real image. Visual ChatGPT is able to use tools in a sequence, and is loyal to the tool observation outputs rather than faking the image content and image file name. It will remember to provide the file name from the last tool observation, if a new image is generated. @@ -81,15 +83,6 @@ def seed_everything(seed): return seed -def prompts(name, description): - def decorator(func): - func.name = name - func.description = description - return func - - return decorator - - def blend_gt2pt(old_image, new_image, sigma=0.15, steps=100): new_size = new_image.size old_size = old_image.size @@ -147,39 +140,6 @@ def blend_gt2pt(old_image, new_image, sigma=0.15, steps=100): return gaussian_img -def cut_dialogue_history(history_memory, keep_last_n_words=500): - if history_memory is None or len(history_memory) == 0: - return history_memory - tokens = history_memory.split() - n_tokens = len(tokens) - print(f"history_memory:{history_memory}, n_tokens: {n_tokens}") - if n_tokens < keep_last_n_words: - return history_memory - paragraphs = history_memory.split('\n') - last_n_tokens = n_tokens - while last_n_tokens >= keep_last_n_words: - last_n_tokens -= len(paragraphs[0].split(' ')) - paragraphs = paragraphs[1:] - return '\n' + '\n'.join(paragraphs) - - -def get_new_image_name(org_img_name, func_name="update"): - head_tail = os.path.split(org_img_name) - head = head_tail[0] - tail = head_tail[1] - name_split = tail.split('.')[0].split('_') - this_new_uuid = str(uuid.uuid4())[:4] - if len(name_split) == 1: - most_org_file_name = name_split[0] - else: - assert len(name_split) == 4 - most_org_file_name = name_split[3] - recent_prev_file_name = name_split[0] - new_file_name = f'{this_new_uuid}_{func_name}_{recent_prev_file_name}_{most_org_file_name}.png' - return os.path.join(head, new_file_name) - - - class MaskFormer: def __init__(self, device): print(f"Initializing MaskFormer to {device}") @@ -295,7 +255,7 @@ def __init__(self, device): "like: generate an image of an object or something, or generate an image that includes some objects. " "The input to this tool should be a string, representing the text used to generate image. ") def inference(self, text): - image_filename = os.path.join('image', f"{str(uuid.uuid4())[:8]}.png") + image_filename = get_image_name() prompt = text + ', ' + self.a_prompt image = self.pipe(prompt, negative_prompt=self.n_prompt).images[0] image.save(image_filename) @@ -1021,7 +981,7 @@ def run_text(self, text, state): return state, state def run_image(self, image, state, txt): - image_filename = os.path.join('image', f"{str(uuid.uuid4())[:8]}.png") + image_filename = get_image_name() print("======>Auto Resize Image...") img = Image.open(image.name) width, height = img.size @@ -1046,11 +1006,13 @@ def run_image(self, image, state, txt): if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--load', type=str, default="ImageCaptioning_cuda:0,Text2Image_cuda:0") + parser.add_argument('--host', type=str, default="0.0.0.0") + parser.add_argument('--port', type=int, default=1015) args = parser.parse_args() load_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in args.load.split(',')} bot = ConversationBot(load_dict=load_dict) with gr.Blocks(css="#chatbot .overflow-y-auto{height:500px}") as demo: - chatbot = gr.Chatbot(elem_id="chatbot", label="Visual ChatGPT") + chatbot = gr.Chatbot(elem_id="chatbot", label="Visual ChatGPT").style(height=800) state = gr.State([]) with gr.Row(): with gr.Column(scale=0.7): @@ -1067,4 +1029,4 @@ def run_image(self, image, state, txt): clear.click(bot.memory.clear) clear.click(lambda: [], None, chatbot) clear.click(lambda: [], None, state) - demo.launch(server_name="0.0.0.0", server_port=1015) + demo.launch(server_name=args.host, server_port=args.port)