diff --git a/client/clip_client/client.py b/client/clip_client/client.py
index 9a2b457c2..777ae9ebf 100644
--- a/client/clip_client/client.py
+++ b/client/clip_client/client.py
@@ -11,6 +11,7 @@
     Generator,
     Iterable,
     Dict,
+    List,
 )
 from urllib.parse import urlparse
 from functools import partial
@@ -66,10 +67,10 @@ def __init__(self, server: str, credential: dict = {}, **kwargs):
     def encode(
         self,
         content: Iterable[str],
+        parameters: dict = {},
         *,
         batch_size: Optional[int] = None,
         show_progress: bool = False,
-        parameters: Optional[dict] = None,
     ) -> 'np.ndarray':
         """Encode images and texts into embeddings where the input is an iterable of raw strings.
         Each image and text must be represented as a string. The following strings are acceptable:
@@ -80,7 +81,8 @@ def encode(
         :param content: an iterator of image URIs or sentences, each element is an image or a text sentence as a string.
         :param batch_size: the number of elements in each request when sending ``content``
         :param show_progress: if set, show a progress bar
-        :param parameters: the parameters for the encoding, you can specify the model to use when you have multiple models
+        :param parameters: the parameters for the encoding. Supported keys:
+            - ``model``: the model to use when the server hosts multiple models
         :return: the embedding in a numpy ndarray with shape ``[N, D]``. ``N`` is in the same length of ``content``
         """
         ...
@@ -89,21 +91,22 @@ def encode(
     def encode(
         self,
         content: Union['DocumentArray', Iterable['Document']],
+        parameters: dict = {},
         *,
         batch_size: Optional[int] = None,
         show_progress: bool = False,
-        parameters: Optional[dict] = None,
     ) -> 'DocumentArray':
         """Encode images and texts into embeddings where the input is an iterable of :class:`docarray.Document`.
         :param content: an iterable of :class:`docarray.Document`, each Document must be filled with `.uri`, `.text` or `.blob`.
         :param batch_size: the number of elements in each request when sending ``content``
         :param show_progress: if set, show a progress bar
-        :param parameters: the parameters for the encoding, you can specify the model to use when you have multiple models
+        :param parameters: the parameters for the encoding. Supported keys:
+            - ``model``: the model to use when the server hosts multiple models
         :return: the embedding in a numpy ndarray with shape ``[N, D]``. ``N`` is in the same length of ``content``
         """
         ...
-    def encode(self, content, **kwargs):
+    def encode(self, content, parameters: dict = {}, **kwargs):
         if isinstance(content, str):
             raise TypeError(
                 f'content must be an Iterable of [str, Document], try `.encode(["{content}"])` instead'
             )
@@ -114,10 +117,17 @@ def encode(self, content, **kwargs):
             total=len(content) if hasattr(content, '__len__') else None,
         )
         results = DocumentArray()
+
+        model_name = parameters.get('model', '')
+        post_args = self._get_post_args(content, kwargs)
+
         with self._pbar:
             self._client.post(
-                **self._get_post_payload(content, kwargs),
+                on=f'/encode/{model_name}'.rstrip('/'),
+                inputs=self._iter_doc(content),
+                parameters=parameters,
                 on_done=partial(self._gather_result, results=results),
+                **post_args,
             )
         return self._unboxed_result(results)
@@ -188,20 +198,16 @@ def _iter_doc(self, content) -> Generator['Document', None, None]:
             ),
         )

-    def _get_post_payload(self, content, kwargs):
-        parameters = kwargs.get('parameters', {})
-        model_name = parameters.get('model', '')
-        payload = dict(
-            on=f'/encode/{model_name}'.rstrip('/'),
-            inputs=self._iter_doc(content),
+    def _get_post_args(self, content, kwargs):
+        post_args = dict(
             request_size=kwargs.get('batch_size', 8),
             total_docs=len(content) if hasattr(content, '__len__') else None,
         )
         if self._scheme == 'grpc' and self._authorization:
-            payload.update(metadata=('authorization', self._authorization))
+            post_args.update(metadata=('authorization', self._authorization))
         elif self._scheme == 'http' and self._authorization:
-            payload.update(headers={'Authorization': self._authorization})
-        return payload
+            post_args.update(headers={'Authorization': self._authorization})
+        return post_args

     def profile(self, content: Optional[str] = '') -> Dict[str, float]:
         """Profiling a single query's roundtrip including network and computation latency. Results is summarized in a table.
@@ -259,6 +265,7 @@ def make_table(_title, _time, _percent):
     async def aencode(
         self,
         content: Iterator[str],
+        parameters: dict = {},
         *,
         batch_size: Optional[int] = None,
         show_progress: bool = False,
@@ -269,13 +276,14 @@
     async def aencode(
         self,
         content: Union['DocumentArray', Iterable['Document']],
+        parameters: dict = {},
         *,
         batch_size: Optional[int] = None,
         show_progress: bool = False,
     ) -> 'DocumentArray':
         ...
-    async def aencode(self, content, **kwargs):
+    async def aencode(self, content, parameters: dict = {}, **kwargs):
         from rich import filesize

         self._prepare_streaming(
@@ -284,19 +292,29 @@ async def aencode(self, content, **kwargs):
         )
         results = DocumentArray()
-        async for da in self._async_client.post(
-            **self._get_post_payload(content, kwargs)
-        ):
-            if not results:
-                self._pbar.start_task(self._r_task)
-            results.extend(da)
-            self._pbar.update(
-                self._r_task,
-                advance=len(da),
-                total_size=str(
-                    filesize.decimal(int(os.environ.get('JINA_GRPC_RECV_BYTES', '0')))
-                ),
-            )
+
+        model_name = parameters.get('model', '')
+        post_args = self._get_post_args(content, kwargs)
+
+        with self._pbar:
+            async for da in self._async_client.post(
+                on=f'/encode/{model_name}'.rstrip('/'),
+                inputs=self._iter_doc(content),
+                parameters=parameters,
+                **post_args,
+            ):
+                if not results:
+                    self._pbar.start_task(self._r_task)
+                results.extend(da)
+                self._pbar.update(
+                    self._r_task,
+                    advance=len(da),
+                    total_size=str(
+                        filesize.decimal(
+                            int(os.environ.get('JINA_GRPC_RECV_BYTES', '0'))
+                        )
+                    ),
+                )

         return self._unboxed_result(results)
@@ -369,30 +387,17 @@ def _iter_rank_docs(
             ),
         )

-    def _get_rank_payload(self, content, kwargs):
-        parameters = kwargs.get('parameters', {})
-        model_name = parameters.get('model', '')
-        payload = dict(
-            on=f'/rank/{model_name}'.rstrip('/'),
-            inputs=self._iter_rank_docs(
-                content, _source=kwargs.get('source', 'matches')
-            ),
-            request_size=kwargs.get('batch_size', 8),
-            total_docs=len(content) if hasattr(content, '__len__') else None,
-        )
-        if self._scheme == 'grpc' and self._authorization:
-            payload.update(metadata=('authorization', self._authorization))
-        elif self._scheme == 'http' and self._authorization:
-            payload.update(headers={'Authorization': self._authorization})
-        return payload
-
-    def rank(self, docs: Iterable['Document'], **kwargs) -> 'DocumentArray':
+    def rank(
+        self, docs: Iterable['Document'], parameters: dict = {}, **kwargs
+    ) -> 'DocumentArray':
         """Rank image-text matches according to the server CLIP model.
         Given a Document with nested matches, where the root is image/text and the matches is in another modality, i.e. text/image; this method ranks the matches according to the CLIP model.
         Each match now has a new score inside ``clip_score`` and matches are sorted descendingly according to this score.
         More details can be found in: https://github.com/openai/CLIP#usage
         :param docs: the input Documents
+        :param parameters: the parameters for ranking. Supported keys:
+            - ``model``: the model to use when the server hosts multiple models
         :return: the ranked Documents in a DocumentArray.
""" self._prepare_streaming( @@ -400,14 +405,25 @@ def rank(self, docs: Iterable['Document'], **kwargs) -> 'DocumentArray': total=len(docs), ) results = DocumentArray() + + model_name = parameters.get('model', '') + post_args = self._get_post_args(docs, kwargs) + with self._pbar: self._client.post( - **self._get_rank_payload(docs, kwargs), + on=f'/rank/{model_name}'.rstrip('/'), + inputs=self._iter_rank_docs( + docs, _source=kwargs.get('source', 'matches') + ), + parameters=parameters, on_done=partial(self._gather_result, results=results), + **post_args, ) return results - async def arank(self, docs: Iterable['Document'], **kwargs) -> 'DocumentArray': + async def arank( + self, docs: Iterable['Document'], parameters: dict = {}, **kwargs + ) -> 'DocumentArray': from rich import filesize self._prepare_streaming( @@ -415,16 +431,300 @@ async def arank(self, docs: Iterable['Document'], **kwargs) -> 'DocumentArray': total=len(docs), ) results = DocumentArray() - async for da in self._async_client.post(**self._get_rank_payload(docs, kwargs)): - if not results: - self._pbar.start_task(self._r_task) - results.extend(da) - self._pbar.update( - self._r_task, - advance=len(da), - total_size=str( - filesize.decimal(int(os.environ.get('JINA_GRPC_RECV_BYTES', '0'))) + + model_name = parameters.get('model', '') + post_args = self._get_post_args(docs, kwargs) + + with self._pbar: + async for da in self._async_client.post( + on=f'/rank/{model_name}'.rstrip('/'), + inputs=self._iter_rank_docs( + docs, _source=kwargs.get('source', 'matches') ), + parameters=parameters, + **post_args, + ): + if not results: + self._pbar.start_task(self._r_task) + results.extend(da) + self._pbar.update( + self._r_task, + advance=len(da), + total_size=str( + filesize.decimal( + int(os.environ.get('JINA_GRPC_RECV_BYTES', '0')) + ) + ), + ) + + return results + + @overload + def index( + self, + content: Iterable[str], + *, + batch_size: Optional[int] = None, + show_progress: bool = False, + ): + """Index the embeddings created by server CLIP model. + Given the list of ``Document`` or strings, this function create an indexer which index + the embeddings. This will be used for top k search. ``AnnLiteIndexer`` is used + by default. + :param content: docs to be indexed. + :param batch_size: the number of elements in each request when sending ``content``. + :param show_progress: if set, show a progress bar. + """ + ... + + @overload + def index( + self, + content: Union['DocumentArray', Iterable['Document']], + *, + batch_size: Optional[int] = None, + show_progress: bool = False, + ): + """Index the embeddings created by server CLIP model. + Given the list of ``Document`` or strings, this function create an indexer which index + the embeddings. This will be used for top k search. ``AnnLiteIndexer`` is used + by default. + :param content: docs to be indexed. + :param batch_size: the number of elements in each request when sending ``content``. + :param show_progress: if set, show a progress bar. + """ + ... 
+
+    def index(self, content, **kwargs):
+        if isinstance(content, str):
+            raise TypeError(
+                f'content must be an Iterable of [str, Document], try `.index(["{content}"])` instead'
+            )
+
+        self._prepare_streaming(
+            not kwargs.get('show_progress'),
+            total=len(content) if hasattr(content, '__len__') else None,
+        )
+        results = DocumentArray()
+        post_args = self._get_post_args(content, kwargs)
+
+        with self._pbar:
+            self._client.post(
+                on='/index',
+                inputs=self._iter_doc(content),
+                on_done=partial(self._gather_result, results=results),
+                **post_args,
+            )
+        return results
+
+    @overload
+    async def aindex(
+        self,
+        content: Iterator[str],
+        *,
+        batch_size: Optional[int] = None,
+        show_progress: bool = False,
+    ):
+        ...
+
+    @overload
+    async def aindex(
+        self,
+        content: Union['DocumentArray', Iterable['Document']],
+        *,
+        batch_size: Optional[int] = None,
+        show_progress: bool = False,
+    ):
+        ...
+
+    async def aindex(self, content, **kwargs):
+        from rich import filesize
+
+        self._prepare_streaming(
+            not kwargs.get('show_progress'),
+            total=len(content) if hasattr(content, '__len__') else None,
+        )
+        results = DocumentArray()
+        post_args = self._get_post_args(content, kwargs)
+
+        with self._pbar:
+            async for da in self._async_client.post(
+                on='/index',
+                inputs=self._iter_doc(content),
+                **post_args,
+            ):
+                if not results:
+                    self._pbar.start_task(self._r_task)
+                results.extend(da)
+                self._pbar.update(
+                    self._r_task,
+                    advance=len(da),
+                    total_size=str(
+                        filesize.decimal(
+                            int(os.environ.get('JINA_GRPC_RECV_BYTES', '0'))
+                        )
+                    ),
+                )
+        return results
+
+    @overload
+    def search(
+        self,
+        content: Iterable[str],
+        parameters: dict = {},
+        *,
+        batch_size: Optional[int] = None,
+        show_progress: bool = False,
+    ) -> 'DocumentArray':
+        """Search the index for the top-k matches of each query.
+        Each element of ``content`` is used as a query, either a raw string or a ``Document``.
+        :param content: list of queries.
+        :param parameters: the parameters for the search. Supported keys:
+            - ``limit``: int, return the top ``limit`` results. Default is 10.
+            - ``filter``: dict, the filter to apply when querying. Default is None.
+            - ``include_metadata``: bool, whether to return document metadata in the response. Default is True.
+        :param batch_size: the number of elements in each request when sending ``content``.
+        :param show_progress: if set, show a progress bar.
+        :return: the top ``limit`` matches for each query.
+        """
+        ...
+
+    @overload
+    def search(
+        self,
+        content: Union['DocumentArray', Iterable['Document']],
+        parameters: dict = {},
+        *,
+        batch_size: Optional[int] = None,
+        show_progress: bool = False,
+    ) -> 'DocumentArray':
+        """Search the index for the top-k matches of each query.
+        Each element of ``content`` is used as a query, either a raw string or a ``Document``.
+        :param content: list of queries.
+        :param parameters: the parameters for the search. Supported keys:
+            - ``limit``: int, return the top ``limit`` results. Default is 10.
+            - ``filter``: dict, the filter to apply when querying. Default is None.
+            - ``include_metadata``: bool, whether to return document metadata in the response. Default is True.
+        :param batch_size: the number of elements in each request when sending ``content``.
+        :param show_progress: if set, show a progress bar.
+        :return: the top ``limit`` matches for each query.
+        """
+        ...
+
+    def search(self, content, parameters: dict = {}, **kwargs) -> 'DocumentArray':
+        if isinstance(content, str):
+            raise TypeError(
+                f'content must be an Iterable of [str, Document], try `.search(["{content}"])` instead'
+            )
+
+        self._prepare_streaming(
+            not kwargs.get('show_progress'),
+            total=len(content) if hasattr(content, '__len__') else None,
+        )
+        results = DocumentArray()
+
+        post_args = self._get_post_args(content, kwargs)
+
+        with self._pbar:
+            self._client.post(
+                on='/search',
+                inputs=self._iter_doc(content),
+                parameters=parameters,
+                on_done=partial(self._gather_result, results=results),
+                **post_args,
+            )
+        return results
+
+    @overload
+    async def asearch(
+        self,
+        content: Iterator[str],
+        parameters: dict = {},
+        *,
+        limit: Optional[int] = 10,
+        batch_size: Optional[int] = None,
+        show_progress: bool = False,
+    ):
+        ...
+
+    @overload
+    async def asearch(
+        self,
+        content: Union['DocumentArray', Iterable['Document']],
+        parameters: dict = {},
+        *,
+        limit: Optional[int] = 10,
+        batch_size: Optional[int] = None,
+        show_progress: bool = False,
+    ):
+        ...
+
+    async def asearch(self, content, parameters: dict = {}, **kwargs):
+        from rich import filesize
+
+        self._prepare_streaming(
+            not kwargs.get('show_progress'),
+            total=len(content) if hasattr(content, '__len__') else None,
+        )
+        results = DocumentArray()
+
+        post_args = self._get_post_args(content, kwargs)
+
+        with self._pbar:
+            async for da in self._async_client.post(
+                on='/search',
+                inputs=self._iter_doc(content),
+                parameters=parameters,
+                **post_args,
+            ):
+                if not results:
+                    self._pbar.start_task(self._r_task)
+                results.extend(da)
+                self._pbar.update(
+                    self._r_task,
+                    advance=len(da),
+                    total_size=str(
+                        filesize.decimal(
+                            int(os.environ.get('JINA_GRPC_RECV_BYTES', '0'))
+                        )
+                    ),
+                )
+        return results
+
+    def status(self):
+        return self._client.post(on='/status')
+
+    def update(self, content, parameters: dict = {}, **kwargs):
+        self._prepare_streaming(
+            not kwargs.get('show_progress'),
+            total=len(content) if hasattr(content, '__len__') else None,
+        )
+
+        results = DocumentArray()
+        post_args = self._get_post_args(content, kwargs)
+
+        with self._pbar:
+            self._client.post(
+                on='/update',
+                inputs=self._iter_doc(content),
+                parameters=parameters,
+                on_done=partial(self._gather_result, results=results),
+                **post_args,
+            )
+        return results
+
+    def delete(self, content, parameters: dict = {}, **kwargs):
+        self._prepare_streaming(
+            not kwargs.get('show_progress'),
+            total=len(content) if hasattr(content, '__len__') else None,
+        )
+
+        results = DocumentArray()
+        post_args = self._get_post_args(content, kwargs)
+
+        with self._pbar:
+            self._client.post(
+                on='/delete',
+                inputs=self._iter_doc(content),
+                parameters=parameters,
+                on_done=partial(self._gather_result, results=results),
+                **post_args,
+            )
+        return results
diff --git a/scripts/streamlit.py b/scripts/streamlit.py
new file mode 100644
index 000000000..b30e15f9c
--- /dev/null
+++ b/scripts/streamlit.py
@@ -0,0 +1,47 @@
+import numpy as np
+import pandas as pd
+import streamlit as st
+from docarray import DocumentArray, Document
+
+from client.clip_client.client import Client
+
+client = Client('grpc://0.0.0.0:61000')
+st.title('Laion400M retrieval')
+
+
+def display_results(results):
+    st.write('Search results for:', query)
+    cols = st.columns(2)
+
+    for k, m in enumerate(results):
+        image = m.uri
+        col_id = 0 if k % 2 == 0 else 1
+
+        with cols[col_id]:
+            caption = m.text
+            score = m.scores['cosine'].value
+            st.markdown(f'#[{k + 1}] ({score:.3f}) {caption}')
+            cols[col_id].image(image)
+
+    # data = [[r.text, r.uri, r.scores['cosine'].value] for r in results]
+    # df = pd.DataFrame(
+    #     data,
+    #     columns=('caption', 'uri', 'score'))
+
+    # st.table(df)
+
+
+def search(query):
+    res = client.search([query])
+
+    result = res[0].matches[:10]
+    display_results(result)
+
+
+query = st.text_input('Query', 'Type your query here...')
+
+if st.button('search'):
+    message = 'Wait for it...'
+
+    with st.spinner(message):
+        search(query)
diff --git a/server/clip_server/retrieval-flow.yml b/server/clip_server/retrieval-flow.yml
new file mode 100644
index 000000000..e75fc6e0e
--- /dev/null
+++ b/server/clip_server/retrieval-flow.yml
@@ -0,0 +1,23 @@
+jtype: Flow
+version: '1'
+with:
+  port: 61000
+executors:
+  - name: encoder
+    uses:
+      jtype: CLIPEncoder
+      metas:
+        py_modules:
+          - clip_server.executors.clip_torch
+    replicas: 1
+  - name: indexer
+    uses:
+      jtype: AnnLiteIndexer
+      with:
+        dim: 128
+        data_path: workspace
+      metas:
+        py_modules:
+          - annlite.executor
+    shards: 5
+    polling: {'/index': 'ANY', '/search': 'ALL', '/update': 'ALL', '/delete': 'ALL', '/status': 'ALL'}
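
For reference, below is a minimal usage sketch of the retrieval API this diff adds, assuming the Flow defined in server/clip_server/retrieval-flow.yml is running and reachable at grpc://0.0.0.0:61000 (the address scripts/streamlit.py uses). The import path, example texts, and image URIs are illustrative assumptions, not part of the diff:

# Usage sketch (illustrative): exercises encode/index/search as defined above.
# Assumes the retrieval Flow is up and the clip_client package is installed;
# the texts and image URIs are placeholders.
from clip_client import Client

client = Client('grpc://0.0.0.0:61000')

# `parameters` is now an explicit argument; an empty 'model' keeps the default
# '/encode' route, while a non-empty value targets '/encode/<model>'.
embeddings = client.encode(['a photo of a cat'], parameters={'model': ''})
print(embeddings.shape)  # (1, D)

# Index a few documents, then query the AnnLite index. Whether '/index'
# requests are embedded by the encoder first depends on the executor bindings,
# which this diff does not show.
client.index(['left.png', 'right.png'])

results = client.search(
    ['an orange cat'],
    parameters={'limit': 10, 'include_metadata': True},
)
for match in results[0].matches:  # same access pattern as scripts/streamlit.py
    print(match.uri, match.scores['cosine'].value)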