Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft demo #778

Closed
wants to merge 11 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 90 additions & 28 deletions client/clip_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Generator,
Iterable,
Dict,
List,
)
from urllib.parse import urlparse
from functools import partial
Expand Down Expand Up @@ -61,6 +62,7 @@ def __init__(self, server: str, credential: dict = {}, **kwargs):
raise ValueError(f'{server} is not a valid scheme')

self._authorization = credential.get('Authorization', None)
self._batch_size = 128
jemmyshin marked this conversation as resolved.
Show resolved Hide resolved

@overload
def encode(
Expand Down Expand Up @@ -114,9 +116,16 @@ def encode(self, content, **kwargs):
total=len(content) if hasattr(content, '__len__') else None,
)
results = DocumentArray()

parameters = kwargs.get('parameters', {})
model_name = parameters.get('model', '')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This raises an error when parameters=None.

payload = self._get_post_parameters(content, kwargs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rename it as params?

payload.update(on=f'/encode/{model_name}'.rstrip('/'))
payload.update(inputs=self._iter_doc(content))

with self._pbar:
self._client.post(
**self._get_post_payload(content, kwargs),
**payload,
on_done=partial(self._gather_result, results=results),
)
return self._unboxed_result(results)
Expand Down Expand Up @@ -188,12 +197,8 @@ def _iter_doc(self, content) -> Generator['Document', None, None]:
),
)

def _get_post_payload(self, content, kwargs):
parameters = kwargs.get('parameters', {})
model_name = parameters.get('model', '')
def _get_post_parameters(self, content, kwargs):
payload = dict(
on=f'/encode/{model_name}'.rstrip('/'),
inputs=self._iter_doc(content),
request_size=kwargs.get('batch_size', 8),
total_docs=len(content) if hasattr(content, '__len__') else None,
)
Expand Down Expand Up @@ -284,9 +289,14 @@ async def aencode(self, content, **kwargs):
)

results = DocumentArray()
async for da in self._async_client.post(
**self._get_post_payload(content, kwargs)
):

parameters = kwargs.get('parameters', {})
jemmyshin marked this conversation as resolved.
Show resolved Hide resolved
model_name = parameters.get('model', '')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

parameters can be None, which will raise an error here. Please fix it.

payload = self._get_post_parameters(content, kwargs)
payload.update(on=f'/encode/{model_name}'.rstrip('/'))
payload.update(inputs=self._iter_doc(content))

jemmyshin marked this conversation as resolved.
Show resolved Hide resolved
async for da in self._async_client.post(**payload):
if not results:
self._pbar.start_task(self._r_task)
results.extend(da)
Expand Down Expand Up @@ -369,23 +379,6 @@ def _iter_rank_docs(
),
)

def _get_rank_payload(self, content, kwargs):
parameters = kwargs.get('parameters', {})
model_name = parameters.get('model', '')
payload = dict(
on=f'/rank/{model_name}'.rstrip('/'),
inputs=self._iter_rank_docs(
content, _source=kwargs.get('source', 'matches')
),
request_size=kwargs.get('batch_size', 8),
total_docs=len(content) if hasattr(content, '__len__') else None,
)
if self._scheme == 'grpc' and self._authorization:
payload.update(metadata=('authorization', self._authorization))
elif self._scheme == 'http' and self._authorization:
payload.update(headers={'Authorization': self._authorization})
return payload

def rank(self, docs: Iterable['Document'], **kwargs) -> 'DocumentArray':
"""Rank image-text matches according to the server CLIP model.
Given a Document with nested matches, where the root is image/text and the matches is in another modality, i.e.
Expand All @@ -400,9 +393,18 @@ def rank(self, docs: Iterable['Document'], **kwargs) -> 'DocumentArray':
total=len(docs),
)
results = DocumentArray()

parameters = kwargs.get('parameters', {})
model_name = parameters.get('model', '')
payload = self._get_post_parameters(docs, kwargs)
payload.update(on=f'/rank/{model_name}'.rstrip('/'))
payload.update(
inputs=self._iter_rank_docs(docs, _source=kwargs.get('source', 'matches'))
)

with self._pbar:
self._client.post(
**self._get_rank_payload(docs, kwargs),
**payload,
on_done=partial(self._gather_result, results=results),
)
return results
Expand All @@ -415,7 +417,16 @@ async def arank(self, docs: Iterable['Document'], **kwargs) -> 'DocumentArray':
total=len(docs),
)
results = DocumentArray()
async for da in self._async_client.post(**self._get_rank_payload(docs, kwargs)):

parameters = kwargs.get('parameters', {})
model_name = parameters.get('model', '')
payload = self._get_post_parameters(docs, kwargs)
payload.update(on=f'/rank/{model_name}'.rstrip('/'))
payload.update(
inputs=self._iter_rank_docs(docs, _source=kwargs.get('source', 'matches'))
)

async for da in self._async_client.post(**payload):
if not results:
self._pbar.start_task(self._r_task)
results.extend(da)
Expand All @@ -428,3 +439,54 @@ async def arank(self, docs: Iterable['Document'], **kwargs) -> 'DocumentArray':
)

return results

def index(self, content: Iterable['Document'], **kwargs):
"""Index the embeddings created by server CLIP model.
Given the document with embeddings, this function create an indexer which index
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring is not correct. The index and query functions both accept raw text/image docs without embeddings.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's more, we also implement async versions: aindex and aquery.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The index should have the same function signature as encode

the embeddings. This will be used for top k search. ``AnnLiteIndexer`` is used
by default.
:param content: docs to be indexed.
"""
self._prepare_streaming(
not kwargs.get('show_progress'),
total=len(content) if hasattr(content, '__len__') else None,
)

payload = self._get_post_parameters(content, kwargs)
payload.update(on='/index')
jemmyshin marked this conversation as resolved.
Show resolved Hide resolved
payload.update(inputs=self._iter_doc(content))

with self._pbar:
self._client.post(
**payload,
)

def search(self, content: List[str], **kwargs) -> DocumentArray:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

search should also accept either str or Document

"""Search for top k results for given query string or ``Document``.
If the input is a string, will use this string as query. If the input is a
jemmyshin marked this conversation as resolved.
Show resolved Hide resolved
``Document``, will use the ``text`` field as query.
:param content: list of queries.
:return: top limit results.
"""
self._prepare_streaming(
not kwargs.get('show_progress'),
total=len(content) if hasattr(content, '__len__') else None,
)
results = DocumentArray()

payload = self._get_post_parameters(content, kwargs)
payload.update(on='/search')
payload.update(inputs=self._iter_doc(content))

with self._pbar:
self._client.post(
**payload,
on_done=partial(self._gather_result, results=results),
)
return results

def status(self):
    """Send a request to the server's ``/status`` endpoint.

    :return: whatever the underlying client's ``post`` call returns.
    """
    return self._client.post(on='/status')
21 changes: 21 additions & 0 deletions server/clip_server/demo-flow.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
jtype: Flow
version: '1'
with:
port: 61000
executors:
- name: encoder
uses:
jtype: CLIPEncoder
metas:
py_modules:
- clip_server.executors.clip_torch
replicas: 1
- name: indexer
uses:
jtype: AnnLiteIndexer
with:
dim: 512
metas:
py_modules:
- annlite.executor
replicas: 1