diff --git a/server/clip_server/model/mclip_model.py b/server/clip_server/model/mclip_model.py
index 6e59653c3..87a25a4e6 100644
--- a/server/clip_server/model/mclip_model.py
+++ b/server/clip_server/model/mclip_model.py
@@ -73,5 +73,5 @@ def encode_text(
             input_ids=input_ids, attention_mask=attention_mask, **kwargs
         )
 
-    def encode_image(self, pixel_values: torch.Tensor, **kwargs):
+    def encode_image(self, pixel_values: torch.Tensor):
         return self._model.encode_image(pixel_values)
diff --git a/server/clip_server/model/openclip_model.py b/server/clip_server/model/openclip_model.py
index c496331c4..ba436c10c 100644
--- a/server/clip_server/model/openclip_model.py
+++ b/server/clip_server/model/openclip_model.py
@@ -6,11 +6,17 @@
 # Ludwig Schmidt
 
 from typing import TYPE_CHECKING
+from copy import deepcopy
+import torch
 
 from clip_server.model.clip_model import CLIPModel
 from clip_server.model.pretrained_models import get_model_url_md5, download_model
-import open_clip
-from open_clip.openai import load_openai_model
+
+from open_clip.model import (
+    CLIP,
+    convert_weights_to_fp16,
+)
+from open_clip.factory import _MODEL_CONFIGS, load_state_dict, load_openai_model
 
 if TYPE_CHECKING:
     import torch
@@ -20,28 +26,46 @@ class OpenCLIPModel(CLIPModel):
     def __init__(self, name: str, device: str = 'cpu', jit: bool = False, **kwargs):
         super().__init__(name, **kwargs)
 
+        if '::' in name:
+            model_name, pretrained = name.split('::')
+        else:
+            # default pretrained model is from openai
+            model_name = name
+            pretrained = 'openai'
+
+        self._model_name = model_name
+
         model_url, md5sum = get_model_url_md5(name)
-        if model_url:
-            model_path = download_model(model_url, md5sum=md5sum)
+        model_path = download_model(model_url, md5sum=md5sum)
+        if pretrained.lower() == 'openai':
             self._model = load_openai_model(model_path, device=device, jit=jit)
-            self._model_name = name.split('::')[0]
         else:
-            model_name, pretrained = name.split('::')
-            self._model = open_clip.create_model(
-                model_name, pretrained=pretrained, device=device, jit=jit
-            )
-            self._model_name = model_name
+            if model_name in _MODEL_CONFIGS:
+                model_cfg = deepcopy(_MODEL_CONFIGS[model_name])
+            else:
+                raise RuntimeError(f'Model config for {model_name} not found.')
+
+            self._model = CLIP(**model_cfg)
+
+            state_dict = load_state_dict(model_path)
+            self._model.load_state_dict(state_dict, strict=True)
+
+            if str(device) == 'cuda':
+                convert_weights_to_fp16(self._model)
+            if jit:
+                self._model = torch.jit.script(self._model)
+
+            self._model.to(device=torch.device(device))
+            self._model.eval()
 
     @property
     def model_name(self):
         if self._model_name == 'ViT-L/14@336px':
             return 'ViT-L-14-336'
-        elif self._model_name.endswith('-quickgelu'):
-            return self._model_name[:-10]
         return self._model_name.replace('/', '-')
 
     def encode_text(self, input_ids: 'torch.Tensor', **kwargs):
         return self._model.encode_text(input_ids)
 
-    def encode_image(self, pixel_values: 'torch.Tensor', **kwargs):
+    def encode_image(self, pixel_values: 'torch.Tensor'):
         return self._model.encode_image(pixel_values)
diff --git a/server/clip_server/model/pretrained_models.py b/server/clip_server/model/pretrained_models.py
index ba6866e00..949bfec81 100644
--- a/server/clip_server/model/pretrained_models.py
+++ b/server/clip_server/model/pretrained_models.py
@@ -82,7 +82,6 @@
     'ViT-B-32': 224,
     'ViT-B-16': 224,
     'ViT-B-16-plus-240': 240,
-    'ViT-B-16-plus-240': 240,
     'ViT-L-14': 224,
     'ViT-L-14-336': 336,
     'Vit-B-16Plus': 240,
diff --git a/server/setup.py b/server/setup.py
index 3dbade3bc..90117cfb6 100644
--- a/server/setup.py
+++ b/server/setup.py
@@ -49,7 +49,7 @@
         'torchvision',
         'jina>=3.6.0',
         'prometheus-client',
-        'open_clip_torch',
+        'open_clip_torch>=1.3.0',
     ],
     extras_require={
         'onnx': [
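
Quick sanity check (not part of the diff): a minimal sketch exercising both loading branches introduced above. It assumes the server package is installed, that download_model() can fetch weights over the network, and that the bare 'ViT-B-32' name and the 'ViT-B-32::laion400m_e31' tag used here exist in the pretrained-model table; the constructor and encode_image/encode_text signatures are taken from the diff itself.

import torch

from clip_server.model.openclip_model import OpenCLIPModel

# Bare name, no '::' -> pretrained defaults to 'openai', so the checkpoint
# is loaded through load_openai_model().
openai_model = OpenCLIPModel('ViT-B-32', device='cpu')

# 'model::pretrained' name -> CLIP(**model_cfg) + load_state_dict() branch.
# The tag is an assumption for illustration; pick any tag listed in
# pretrained_models.py.
laion_model = OpenCLIPModel('ViT-B-32::laion400m_e31', device='cpu')

with torch.no_grad():
    # Dummy 224x224 RGB batch; real callers pass preprocessed pixel values.
    emb = laion_model.encode_image(torch.randn(1, 3, 224, 224))

print(emb.shape)  # ViT-B-32 typically yields a 512-dim embedding

The 'open_clip_torch>=1.3.0' pin in setup.py matches the new imports of _MODEL_CONFIGS and load_state_dict from open_clip.factory, which presumably are not exposed in older releases.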