[Misc] Add vision language model support to CPU backend #3968

Merged · 11 commits · Apr 22, 2024
1 change: 1 addition & 0 deletions vllm/executor/cpu_executor.py
@@ -45,6 +45,7 @@ def _init_worker(self):
rank=0,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
+            vision_language_config=self.vision_language_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=True,
)
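Note: the one-line executor change above is the first hop in a three-stage hand-off: the CPU executor passes `vision_language_config` to the worker, and the worker passes it on to the model runner (both shown in the diffs below). A minimal sketch of that plumbing, using simplified stand-in classes rather than the real vLLM signatures:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class VisionLanguageConfig:
    """Stand-in for vllm.config.VisionLanguageConfig; the field is illustrative."""
    image_feature_size: int = 576


class ModelRunner:
    def __init__(self, vision_language_config: Optional[VisionLanguageConfig]) -> None:
        # The runner keeps the config and later uses its presence to decide
        # whether to feed image inputs to the model.
        self.vision_language_config = vision_language_config


class Worker:
    def __init__(self, vision_language_config: Optional[VisionLanguageConfig] = None) -> None:
        # The worker does not interpret the config; it only forwards it.
        self.model_runner = ModelRunner(vision_language_config=vision_language_config)


class Executor:
    def __init__(self, vision_language_config: Optional[VisionLanguageConfig] = None) -> None:
        self.vision_language_config = vision_language_config

    def _init_worker(self) -> Worker:
        # Mirrors the one-line change to CPUExecutor._init_worker above.
        return Worker(vision_language_config=self.vision_language_config)


worker = Executor(VisionLanguageConfig())._init_worker()
assert worker.model_runner.vision_language_config is not None
```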
60 changes: 37 additions & 23 deletions vllm/worker/cpu_model_runner.py
@@ -4,7 +4,7 @@

from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig)
+                         ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
@@ -28,6 +28,7 @@ def __init__(
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
+        vision_language_config: Optional[VisionLanguageConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
*args,
@@ -37,6 +38,7 @@
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.lora_config = lora_config
+        self.vision_language_config = vision_language_config
self.load_config = load_config
self.is_driver_worker = is_driver_worker

@@ -57,13 +59,14 @@ def __init__(
self.model_config.dtype if model_config is not None else None)

def load_model(self) -> None:
-        self.model = get_model(model_config=self.model_config,
-                               load_config=self.load_config,
-                               device_config=self.device_config,
-                               vision_language_config=None,
-                               lora_config=self.lora_config,
-                               parallel_config=self.parallel_config,
-                               scheduler_config=self.scheduler_config)
+        self.model = get_model(
+            model_config=self.model_config,
+            load_config=self.load_config,
+            device_config=self.device_config,
+            vision_language_config=self.vision_language_config,
+            lora_config=self.lora_config,
+            parallel_config=self.parallel_config,
+            scheduler_config=self.scheduler_config)

def _prepare_prompt(
self,
@@ -74,6 +77,7 @@ def _prepare_prompt(
input_positions: List[int] = []
slot_mapping: List[int] = []
prompt_lens: List[int] = []
+        multi_modal_input_list: List[torch.Tensor] = []

for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
@@ -94,6 +98,10 @@
# is always the first token in the sequence.
input_positions.extend(list(range(computed_len, prompt_len)))

+            if seq_group_metadata.multi_modal_data:
+                multi_modal_input_list.append(
+                    seq_group_metadata.multi_modal_data.data)

# Compute the slot mapping.
block_table = seq_group_metadata.block_tables[seq_id]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
@@ -116,6 +124,15 @@
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)

+        if multi_modal_input_list:
+            assert self.vision_language_config, (
+                "Multi-modal inputs are only supported by "
+                "vision language models.")
+            multi_modal_input = torch.cat(multi_modal_input_list,
+                                          dim=0).to(self.device)
+        else:
+            multi_modal_input = None

num_prompt_tokens = len(input_tokens)

input_tokens = torch.tensor(input_tokens,
@@ -142,12 +159,8 @@
slot_mapping=slot_mapping,
kv_cache_dtype=self.kv_cache_dtype,
)
-        return (
-            input_tokens,
-            input_positions,
-            attn_metadata,
-            prompt_lens,
-        )
+        return (input_tokens, input_positions, attn_metadata, prompt_lens,
+                multi_modal_input)
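For reference, a standalone sketch of the batching that the new `_prepare_prompt` code above performs: each prompt may carry a per-sequence tensor in `seq_group_metadata.multi_modal_data.data`, and the runner concatenates the collected tensors along dim 0 (or leaves the input as `None` for a text-only batch). The `[1, 3, 336, 336]` pixel-value shape below is an illustrative assumption, not something this diff prescribes:

```python
from typing import List, Optional

import torch


def batch_multi_modal_inputs(
        per_sequence_inputs: List[Optional[torch.Tensor]],
        has_vision_config: bool) -> Optional[torch.Tensor]:
    """Concatenate per-sequence image tensors the way _prepare_prompt does."""
    collected = [t for t in per_sequence_inputs if t is not None]
    if not collected:
        return None
    # Mirrors the assert in the diff: image inputs only make sense when a
    # vision-language config was provided to the runner.
    assert has_vision_config, (
        "Multi-modal inputs are only supported by vision language models.")
    return torch.cat(collected, dim=0)


# Two prompts, each carrying one image; the shape is purely illustrative.
images = [torch.randn(1, 3, 336, 336), torch.randn(1, 3, 336, 336)]
batched = batch_multi_modal_inputs(images, has_vision_config=True)
assert batched is not None and batched.shape[0] == 2
```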

def _prepare_decode(
self,
@@ -331,14 +344,16 @@ def prepare_input_tensors(
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata,
SamplingMetadata]:
+        multi_modal_input = None
if self.is_driver_worker:
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
is_prompt = seq_group_metadata_list[0].is_prompt
# Prepare input tensors.
if is_prompt:
-                (input_tokens, input_positions, attn_metadata,
-                 prompt_lens) = self._prepare_prompt(seq_group_metadata_list)
+                (input_tokens, input_positions, attn_metadata, prompt_lens,
+                 multi_modal_input
+                 ) = self._prepare_prompt(seq_group_metadata_list)
else:
(input_tokens, input_positions,
attn_metadata) = self._prepare_decode(seq_group_metadata_list)
@@ -371,20 +386,17 @@ def prepare_input_tensors(
perform_sampling=False,
)

-        return (
-            input_tokens,
-            input_positions,
-            attn_metadata,
-            sampling_metadata,
-        )
+        return (input_tokens, input_positions, attn_metadata,
+                sampling_metadata, multi_modal_input)

@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
kv_caches: List[torch.Tensor],
) -> Optional[SamplerOutput]:
-        (input_tokens, input_positions, attn_metadata, sampling_metadata
+        (input_tokens, input_positions, attn_metadata, sampling_metadata,
+         multi_modal_input
) = self.prepare_input_tensors(seq_group_metadata_list)

model_executable = self.model
@@ -394,6 +406,8 @@ def execute_model(
"kv_caches": kv_caches,
"attn_metadata": attn_metadata,
}
+        if self.vision_language_config:
+            execute_model_kwargs.update({"image_input": multi_modal_input})

hidden_states = model_executable(**execute_model_kwargs)

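The `execute_model` change above keys the extra `image_input` kwarg off `self.vision_language_config`, so text-only models keep their existing forward signature while vision-language models receive the batched image tensor. A small sketch of that dispatch pattern, with dummy modules standing in for real vLLM models:

```python
from typing import Optional

import torch


class TextOnlyModel(torch.nn.Module):
    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        return input_ids.float() + positions.float()


class VisionLanguageModel(torch.nn.Module):
    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                image_input: Optional[torch.Tensor] = None) -> torch.Tensor:
        out = input_ids.float() + positions.float()
        if image_input is not None:
            # A real model would embed the image features here; this is a stub.
            out = out + image_input.mean()
        return out


def run(model: torch.nn.Module, has_vision_config: bool,
        input_ids: torch.Tensor, positions: torch.Tensor,
        image_input: Optional[torch.Tensor]) -> torch.Tensor:
    kwargs = {"input_ids": input_ids, "positions": positions}
    # Only add the kwarg when a vision-language config exists, exactly as the
    # diff does with execute_model_kwargs.update({"image_input": ...}).
    if has_vision_config:
        kwargs["image_input"] = image_input
    return model(**kwargs)


ids, pos = torch.arange(4), torch.arange(4)
run(TextOnlyModel(), False, ids, pos, None)
run(VisionLanguageModel(), True, ids, pos, torch.randn(1, 3, 8, 8))
```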
24 changes: 15 additions & 9 deletions vllm/worker/cpu_worker.py
@@ -6,7 +6,8 @@

from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
-                         ModelConfig, ParallelConfig, SchedulerConfig)
+                         ModelConfig, ParallelConfig, SchedulerConfig,
+                         VisionLanguageConfig)
from vllm.distributed import (broadcast_tensor_dict,
ensure_model_parallel_initialized,
init_distributed_environment)
@@ -122,6 +123,7 @@ def __init__(
rank: int,
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
+        vision_language_config: Optional[VisionLanguageConfig] = None,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
) -> None:
@@ -135,21 +137,25 @@ def __init__(
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
+        self.vision_language_config = vision_language_config
self.is_driver_worker = is_driver_worker
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."

if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
-        self.model_runner = CPUModelRunner(model_config,
-                                           parallel_config,
-                                           scheduler_config,
-                                           device_config,
-                                           load_config=self.load_config,
-                                           lora_config=self.lora_config,
-                                           kv_cache_dtype=kv_cache_dtype,
-                                           is_driver_worker=is_driver_worker)
+        self.model_runner = CPUModelRunner(
+            model_config,
+            parallel_config,
+            scheduler_config,
+            device_config,
+            load_config=self.load_config,
+            lora_config=self.lora_config,
+            vision_language_config=self.vision_language_config,
+            kv_cache_dtype=kv_cache_dtype,
+            is_driver_worker=is_driver_worker)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
self.cache_engine = None