
feat: update openclip loader #782

Merged
merged 7 commits on Jul 26, 2022
2 changes: 1 addition & 1 deletion server/clip_server/model/mclip_model.py
@@ -73,5 +73,5 @@ def encode_text(
         input_ids=input_ids, attention_mask=attention_mask, **kwargs
     )

-    def encode_image(self, pixel_values: torch.Tensor, **kwargs):
+    def encode_image(self, pixel_values: torch.Tensor):
         return self._model.encode_image(pixel_values)
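The dropped **kwargs makes encode_image reject stray keyword arguments instead of silently ignoring them. A minimal sketch of what that means for callers, using a hypothetical stand-in class that only mirrors the new signature from this diff:

import torch

class _SketchModel:
    # Hypothetical stand-in; only the signature is taken from this diff.
    def encode_image(self, pixel_values: torch.Tensor):
        # Dummy embedding: average over spatial dims, not the real model.
        return pixel_values.mean(dim=(2, 3))

model = _SketchModel()
pixels = torch.randn(1, 3, 224, 224)  # assumed NCHW batch for a 224px model
emb = model.encode_image(pixels)      # fine
# model.encode_image(pixels, foo=1)   # TypeError after this change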
50 changes: 37 additions & 13 deletions server/clip_server/model/openclip_model.py
@@ -6,11 +6,17 @@
 # Ludwig Schmidt

 from typing import TYPE_CHECKING
+from copy import deepcopy
+import torch

 from clip_server.model.clip_model import CLIPModel
 from clip_server.model.pretrained_models import get_model_url_md5, download_model
-import open_clip
-from open_clip.openai import load_openai_model
+
+from open_clip.model import (
+    CLIP,
+    convert_weights_to_fp16,
+)
+from open_clip.factory import _MODEL_CONFIGS, load_state_dict, load_openai_model

 if TYPE_CHECKING:
     import torch
@@ -20,28 +26,46 @@ class OpenCLIPModel(CLIPModel):
     def __init__(self, name: str, device: str = 'cpu', jit: bool = False, **kwargs):
         super().__init__(name, **kwargs)

+        if '::' in name:
+            model_name, pretrained = name.split('::')
+        else:
+            # default pretrained model is from openai
+            model_name = name
+            pretrained = 'openai'
+
+        self._model_name = model_name
+
         model_url, md5sum = get_model_url_md5(name)
-        if model_url:
-            model_path = download_model(model_url, md5sum=md5sum)
+        model_path = download_model(model_url, md5sum=md5sum)
+        if pretrained.lower() == 'openai':
             self._model = load_openai_model(model_path, device=device, jit=jit)
-            self._model_name = name.split('::')[0]
         else:
Member: we should make sure all of the models are uploaded to our s3 bucket.

Member: i uploaded pt models, should be available in any minutes

Member: update: any hours 😅 poor internet

-            model_name, pretrained = name.split('::')
-            self._model = open_clip.create_model(
-                model_name, pretrained=pretrained, device=device, jit=jit
-            )
-            self._model_name = model_name
+            if model_name in _MODEL_CONFIGS:
+                model_cfg = deepcopy(_MODEL_CONFIGS[model_name])
+            else:
+                raise RuntimeError(f'Model config for {model_name} not found.')
+
+            self._model = CLIP(**model_cfg)
+
+            state_dict = load_state_dict(model_path)
+            self._model.load_state_dict(state_dict, strict=True)
+
+            if str(device) == 'cuda':
+                convert_weights_to_fp16(self._model)
+            if jit:
+                self._model = torch.jit.script(self._model)
+
+            self._model.to(device=torch.device(device))
+            self._model.eval()

+    @property
+    def model_name(self):
+        if self._model_name == 'ViT-L/14@336px':
+            return 'ViT-L-14-336'
+        elif self._model_name.endswith('-quickgelu'):
+            return self._model_name[:-10]
+        return self._model_name.replace('/', '-')
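For reference, the normalization the new model_name property performs; the example names below are illustrative:

# 'ViT-L/14@336px'     -> 'ViT-L-14-336'   (special case)
# 'ViT-B-32-quickgelu' -> 'ViT-B-32'       (10-character '-quickgelu' suffix stripped)
# 'ViT-B/32'           -> 'ViT-B-32'       ('/' replaced with '-')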

     def encode_text(self, input_ids: 'torch.Tensor', **kwargs):
         return self._model.encode_text(input_ids)

-    def encode_image(self, pixel_values: 'torch.Tensor', **kwargs):
+    def encode_image(self, pixel_values: 'torch.Tensor'):
         return self._model.encode_image(pixel_values)
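Putting the new code path together, a hedged sketch of how the loader now resolves names; the '::' syntax and the OpenAI default come from this diff, while the laion2b_e16 tag is an assumed example of an open_clip pretrained name:

from clip_server.model.openclip_model import OpenCLIPModel

# No '::' means pretrained defaults to 'openai' and load_openai_model is used.
openai_model = OpenCLIPModel('ViT-B-32')

# With '::', the weights are downloaded, a CLIP(**model_cfg) is built from
# _MODEL_CONFIGS, and the checkpoint is loaded with strict=True.
laion_model = OpenCLIPModel('ViT-B-32::laion2b_e16')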
1 change: 0 additions & 1 deletion server/clip_server/model/pretrained_models.py
@@ -82,7 +82,6 @@
     'ViT-B-32': 224,
     'ViT-B-16': 224,
     'ViT-B-16-plus-240': 240,
-    'ViT-B-16-plus-240': 240,
     'ViT-L-14': 224,
     'ViT-L-14-336': 336,
     'Vit-B-16Plus': 240,
2 changes: 1 addition & 1 deletion server/setup.py
@@ -49,7 +49,7 @@
     'torchvision',
     'jina>=3.6.0',
     'prometheus-client',
-    'open_clip_torch',
+    'open_clip_torch>=1.3.0',
 ],
 extras_require={
     'onnx': [
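The new floor on open_clip_torch matters because the loader now imports _MODEL_CONFIGS and load_state_dict from open_clip.factory. A quick way to sanity-check an environment against the pin; this snippet is illustrative, not part of the PR:

# pip install "open_clip_torch>=1.3.0"
from open_clip.factory import _MODEL_CONFIGS, load_state_dict

# List a few of the bundled model configs the loader can instantiate.
print(sorted(_MODEL_CONFIGS)[:5])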