diff --git a/.gitmodules b/.gitmodules index 7edf097..5cc3e08 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,3 @@ [submodule "vendor/llama.cpp"] path = vendor/llama.cpp - url = https://github.com/ggerganov/llama.cpp.git + url = http://github.com/inference-sh/llama.cpp diff --git a/CHANGELOG.md b/CHANGELOG.md index 605370e..affbd5d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.3.9] + +- feat: Update llama.cpp to ggerganov/llama.cpp@8733e0cf6eefc7c7752297cc22d0836706f4222c + ## [0.3.8] - feat: Update llama.cpp to ggerganov/llama.cpp@7841fc723e059d1fd9640e5c0ef19050fcc7c698 diff --git a/CMakeLists.txt b/CMakeLists.txt index 64a0304..47073ea 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3,7 +3,7 @@ cmake_minimum_required(VERSION 3.21) project(llama_cpp) option(LLAMA_BUILD "Build llama.cpp shared library and install alongside python package" ON) -option(LLAVA_BUILD "Build llava shared library and install alongside python package" ON) +option(MTMD_BUILD "Build multimodal (mtmd) shared library and install alongside python package" ON) function(llama_cpp_python_install_target target) if(NOT TARGET ${target}) @@ -62,6 +62,9 @@ if (LLAMA_BUILD) # Enable building of the common library set(LLAMA_BUILD_COMMON ON CACHE BOOL "Build llama.cpp common library" FORCE) + # Disable building curl support + set(LLAMA_CURL OFF CACHE BOOL "llama.cpp: enable curl" FORCE) + # Architecture detection and settings for Apple platforms if (APPLE) # Get the target architecture @@ -132,7 +135,7 @@ if (LLAMA_BUILD) ) endif() - if (LLAVA_BUILD) + if (MTMD_BUILD) if (LLAMA_CUBLAS OR LLAMA_CUDA) add_compile_definitions(GGML_USE_CUBLAS) add_compile_definitions(GGML_USE_CUDA) @@ -142,36 +145,26 @@ if (LLAMA_BUILD) add_compile_definitions(GGML_USE_METAL) endif() - # Building llava - add_subdirectory(vendor/llama.cpp/examples/llava) - set_target_properties(llava_shared PROPERTIES OUTPUT_NAME "llava") + # Building multimodal support using mtmd + add_subdirectory(vendor/llama.cpp/tools/mtmd) if (WIN32) - set_target_properties(llava_shared PROPERTIES CUDA_ARCHITECTURES OFF) + set_target_properties(mtmd PROPERTIES CUDA_ARCHITECTURES OFF) endif() - llama_cpp_python_install_target(llava_shared) + llama_cpp_python_install_target(mtmd) if (WIN32) install( - FILES $ + FILES $ DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/llama_cpp/lib ) install( - FILES $ + FILES $ DESTINATION ${SKBUILD_PLATLIB_DIR}/llama_cpp/lib ) endif() - # Fix for llava build: Add include directory for llama.h - # Move these commands after the add_subdirectory call - target_include_directories(llava PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/vendor/llama.cpp/include) - target_include_directories(llava PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/vendor/llama.cpp/ggml/include) - - if (BUILD_SHARED_LIBS) - target_include_directories(llava_shared PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/vendor/llama.cpp/include) - target_include_directories(llava_shared PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/vendor/llama.cpp/ggml/include) - endif() - - target_include_directories(llama-llava-cli PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/vendor/llama.cpp/include) - target_include_directories(llama-minicpmv-cli PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/vendor/llama.cpp/include) + # Add include directories for mtmd + target_include_directories(mtmd PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/vendor/llama.cpp/include) + target_include_directories(mtmd PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/vendor/llama.cpp/ggml/include) 
endif() endif() diff --git a/examples/notebooks/Batching.ipynb b/examples/notebooks/Batching.ipynb index be7fe9b..b1992e9 100644 --- a/examples/notebooks/Batching.ipynb +++ b/examples/notebooks/Batching.ipynb @@ -230,7 +230,7 @@ "outputs": [], "source": [ "for i in range(n_parallel):\n", - " llama_cpp.llama_kv_cache_seq_cp(ctx, 0, i, 0, batch.n_tokens)" + " llama_cpp.llama_kv_self_seq_cp(ctx, 0, i, 0, batch.n_tokens)" ] }, { diff --git a/llama_cpp/__init__.py b/llama_cpp/__init__.py index b1a8b9b..2c9c527 100644 --- a/llama_cpp/__init__.py +++ b/llama_cpp/__init__.py @@ -1,4 +1,4 @@ from .llama_cpp import * from .llama import * -__version__ = "0.3.8" +__version__ = "0.3.9" diff --git a/llama_cpp/_ctypes_extensions.py b/llama_cpp/_ctypes_extensions.py index e88ed38..032e983 100644 --- a/llama_cpp/_ctypes_extensions.py +++ b/llama_cpp/_ctypes_extensions.py @@ -128,4 +128,4 @@ def _byref(obj: CtypesCData, offset: Optional[int] = None) -> CtypesRef[CtypesCD ... -byref = _byref if TYPE_CHECKING else ctypes.byref +byref = _byref if TYPE_CHECKING else ctypes.byref \ No newline at end of file diff --git a/llama_cpp/_internals.py b/llama_cpp/_internals.py index 343581d..e0b8548 100644 --- a/llama_cpp/_internals.py +++ b/llama_cpp/_internals.py @@ -2,6 +2,7 @@ import os import ctypes +from enum import Enum from typing import ( Dict, @@ -24,7 +25,13 @@ # Python wrappers over llama.h structs - +class LlamaBackendDev(Enum): + # CPU device using system memory + CPU = 0 + # GPU device using dedicated memory + GPU = 1 + # accelerator devices intended to be used together with the CPU backend (e.g. BLAS or AMX) + ACCEL = 2 class LlamaModel: """Intermediate Python wrapper for a llama.cpp llama_model. @@ -88,6 +95,12 @@ def n_ctx_train(self) -> int: def n_embd(self) -> int: return llama_cpp.llama_n_embd(self.model) + + def n_layer(self) -> int: + return llama_cpp.llama_n_layer(self.model) + + def dev_layer(self, il: int) -> LlamaBackendDev: + return LlamaBackendDev(llama_cpp.llama_model_dev_layer(self.model, il)) def rope_freq_scale_train(self) -> float: return llama_cpp.llama_model_rope_freq_scale_train(self.model) @@ -276,20 +289,20 @@ def n_ctx(self) -> int: def pooling_type(self) -> int: return llama_cpp.llama_pooling_type(self.ctx) - def kv_cache_clear(self): - llama_cpp.llama_kv_cache_clear(self.ctx) + def kv_self_clear(self): + llama_cpp.llama_kv_self_clear(self.ctx) - def kv_cache_seq_rm(self, seq_id: int, p0: int, p1: int): - llama_cpp.llama_kv_cache_seq_rm(self.ctx, seq_id, p0, p1) + def kv_self_seq_rm(self, seq_id: int, p0: int, p1: int): + llama_cpp.llama_kv_self_seq_rm(self.ctx, seq_id, p0, p1) - def kv_cache_seq_cp(self, seq_id_src: int, seq_id_dst: int, p0: int, p1: int): - llama_cpp.llama_kv_cache_seq_cp(self.ctx, seq_id_src, seq_id_dst, p0, p1) + def kv_self_seq_cp(self, seq_id_src: int, seq_id_dst: int, p0: int, p1: int): + llama_cpp.llama_kv_self_seq_cp(self.ctx, seq_id_src, seq_id_dst, p0, p1) - def kv_cache_seq_keep(self, seq_id: int): - llama_cpp.llama_kv_cache_seq_keep(self.ctx, seq_id) + def kv_self_seq_keep(self, seq_id: int): + llama_cpp.llama_kv_self_seq_keep(self.ctx, seq_id) - def kv_cache_seq_shift(self, seq_id: int, p0: int, p1: int, shift: int): - llama_cpp.llama_kv_cache_seq_add(self.ctx, seq_id, p0, p1, shift) + def kv_self_seq_shift(self, seq_id: int, p0: int, p1: int, shift: int): + llama_cpp.llama_kv_self_seq_add(self.ctx, seq_id, p0, p1, shift) def get_state_size(self) -> int: return llama_cpp.llama_get_state_size(self.ctx) diff --git a/llama_cpp/llama.py 
b/llama_cpp/llama.py index 7e9a6af..93e8e56 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -39,6 +39,7 @@ LlamaRAMCache, # type: ignore ) from .llama_tokenizer import BaseLlamaTokenizer, LlamaTokenizer +from .llama_cpp import GGMLType import llama_cpp.llama_cpp as llama_cpp import llama_cpp.llama_chat_format as llama_chat_format @@ -110,8 +111,8 @@ def __init__( # Tokenizer Override tokenizer: Optional[BaseLlamaTokenizer] = None, # KV cache quantization - type_k: Optional[int] = None, - type_v: Optional[int] = None, + type_k: Optional[GGMLType] = None, + type_v: Optional[GGMLType] = None, # Misc spm_infill: bool = False, verbose: bool = True, @@ -302,60 +303,9 @@ def __init__( ].key = b"\0" # ensure sentinel element is zeroed self.model_params.kv_overrides = self._kv_overrides_array - self.n_batch = min(n_ctx, n_batch) # ??? - self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1) - self.n_threads_batch = n_threads_batch or multiprocessing.cpu_count() - # Used by the sampler self._seed = seed or llama_cpp.LLAMA_DEFAULT_SEED - # Context Params - self.context_params = llama_cpp.llama_context_default_params() - self.context_params.n_ctx = n_ctx - self.context_params.n_batch = self.n_batch - self.context_params.n_ubatch = min(self.n_batch, n_ubatch) - self.context_params.n_threads = self.n_threads - self.context_params.n_threads_batch = self.n_threads_batch - self.context_params.rope_scaling_type = ( - rope_scaling_type - if rope_scaling_type is not None - else llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED - ) - self.context_params.pooling_type = pooling_type - self.context_params.rope_freq_base = ( - rope_freq_base if rope_freq_base != 0.0 else 0 - ) - self.context_params.rope_freq_scale = ( - rope_freq_scale if rope_freq_scale != 0.0 else 0 - ) - self.context_params.yarn_ext_factor = ( - yarn_ext_factor if yarn_ext_factor != 0.0 else 0 - ) - self.context_params.yarn_attn_factor = ( - yarn_attn_factor if yarn_attn_factor != 0.0 else 0 - ) - self.context_params.yarn_beta_fast = ( - yarn_beta_fast if yarn_beta_fast != 0.0 else 0 - ) - self.context_params.yarn_beta_slow = ( - yarn_beta_slow if yarn_beta_slow != 0.0 else 0 - ) - self.context_params.yarn_orig_ctx = yarn_orig_ctx if yarn_orig_ctx != 0 else 0 - self.context_params.logits_all = ( - logits_all if draft_model is None else True - ) # Must be set to True for speculative decoding - self.context_params.embeddings = embedding # TODO: Rename to embeddings - self.context_params.offload_kqv = offload_kqv - self.context_params.flash_attn = flash_attn - # KV cache quantization - if type_k is not None: - self.context_params.type_k = type_k - if type_v is not None: - self.context_params.type_v = type_v - # Sampling Params - self.context_params.no_perf = no_perf - self.last_n_tokens_size = last_n_tokens_size - self.cache: Optional[BaseLlamaCache] = None self.lora_base = lora_base @@ -376,39 +326,45 @@ def __init__( ) ) ) - - # Override tokenizer + + self.draft_model = draft_model + + # Override tokenizer self.tokenizer_ = tokenizer or LlamaTokenizer(self) + + self._n_vocab = self.n_vocab() - # Set the default value for the context and correct the batch - if n_ctx == 0: - n_ctx = self._model.n_ctx_train() - self.n_batch = min(n_ctx, n_batch) - self.context_params.n_ctx = self._model.n_ctx_train() - self.context_params.n_batch = self.n_batch - self.context_params.n_ubatch = min(self.n_batch, n_ubatch) - - self._ctx = self._stack.enter_context( - contextlib.closing( - internals.LlamaContext( - model=self._model, - 
params=self.context_params, - verbose=self.verbose, - ) - ) - ) + self._token_nl = self.token_nl() + self._token_eos = self.token_eos() - self._batch = self._stack.enter_context( - contextlib.closing( - internals.LlamaBatch( - n_tokens=self.n_batch, - embd=0, - n_seq_max=self.context_params.n_ctx, - verbose=self.verbose, - ) - ) + self._candidates = internals.LlamaTokenDataArray(n_vocab=self._n_vocab) + # Context Params + self._create_context( + n_ctx=n_ctx, + n_batch=n_batch, + n_ubatch=min(n_batch, n_ubatch), + n_threads=n_threads, + n_threads_batch=n_threads_batch, + rope_scaling_type=rope_scaling_type, + pooling_type=pooling_type, + rope_freq_base=rope_freq_base, + rope_freq_scale=rope_freq_scale, + yarn_ext_factor=yarn_ext_factor, + yarn_attn_factor=yarn_attn_factor, + yarn_beta_fast=yarn_beta_fast, + yarn_beta_slow=yarn_beta_slow, + yarn_orig_ctx=yarn_orig_ctx, + logits_all=logits_all, + embedding=embedding, + offload_kqv=offload_kqv, + flash_attn=flash_attn, + no_perf=no_perf, + last_n_tokens_size=last_n_tokens_size, + type_k=type_k, + type_v=type_v, ) + self._lora_adapter: Optional[llama_cpp.llama_adapter_lora_p] = None if self.lora_path: @@ -445,22 +401,6 @@ def free_lora_adapter(): str, llama_chat_format.LlamaChatCompletionHandler ] = {} - self.draft_model = draft_model - - self._n_vocab = self.n_vocab() - self._n_ctx = self.n_ctx() - - self._token_nl = self.token_nl() - self._token_eos = self.token_eos() - - self._candidates = internals.LlamaTokenDataArray(n_vocab=self._n_vocab) - - self.n_tokens = 0 - self.input_ids: npt.NDArray[np.intc] = np.ndarray((n_ctx,), dtype=np.intc) - self.scores: npt.NDArray[np.single] = np.ndarray( - (n_ctx if logits_all == True else n_batch, self._n_vocab), dtype=np.single - ) - self._mirostat_mu = ctypes.c_float( 2.0 * 5.0 ) # TODO: Move this to sampling context @@ -542,6 +482,10 @@ def free_lora_adapter(): f"Using fallback chat format: {self.chat_format}", file=sys.stderr ) + if self.chat_handler is not None: + if isinstance(self.chat_handler, llama_chat_format.Llava15ChatHandler): + self.chat_handler.initialize_mtmd_context(self) + self._sampler = None @property @@ -551,6 +495,13 @@ def ctx(self) -> llama_cpp.llama_context_p: @property def model(self) -> llama_cpp.llama_model_p: return self._model.model + + @property + def n_layer(self) -> int: + return self._model.n_layer() + + def dev_layer(self, il: int) -> internals.LlamaBackendDev: + return self._model.dev_layer(il) @property def _input_ids(self) -> npt.NDArray[np.intc]: @@ -635,7 +586,7 @@ def eval(self, tokens: Sequence[int]): Args: tokens: The list of tokens to evaluate. 
""" - self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1) + self._ctx.kv_self_seq_rm(-1, self.n_tokens, -1) for i in range(0, len(tokens), self.n_batch): batch = tokens[i : min(len(tokens), i + self.n_batch)] n_past = self.n_tokens @@ -943,7 +894,7 @@ def generate( if sample_idx < self.n_tokens and token != self._input_ids[sample_idx]: self.n_tokens = sample_idx - self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1) + self._ctx.kv_self_seq_rm(-1, self.n_tokens, -1) break if self.draft_model is not None: @@ -1039,7 +990,7 @@ def embed( data: Union[List[List[float]], List[List[List[float]]]] = [] def decode_batch(seq_sizes: List[int]): - llama_cpp.llama_kv_cache_clear(self._ctx.ctx) + llama_cpp.llama_kv_self_clear(self._ctx.ctx) self._ctx.decode(self._batch) self._batch.reset() @@ -1110,7 +1061,7 @@ def decode_batch(seq_sizes: List[int]): output = data[0] if isinstance(input, str) else data - llama_cpp.llama_kv_cache_clear(self._ctx.ctx) + llama_cpp.llama_kv_self_clear(self._ctx.ctx) self.reset() if return_count: @@ -1118,6 +1069,50 @@ def decode_batch(seq_sizes: List[int]): else: return output + def _create_chunk( + self, + completion_id: str, + created: int, + model_name: str, + text: str, + logprobs_or_none: Union[Optional[CompletionLogprobs], None], + index: int, + finish_reason: Union[str, None], + usage: Optional[Dict[str, Any]] = None, + ) -> CreateCompletionStreamResponse: + """Create chunks for streaming API, depending on whether usage is requested or not.""" + if usage is not None: + return { + "id": completion_id, + "object": "text_completion", + "created": created, + "model": model_name, + "choices": [ + { + "text": text, + "index": index, + "logprobs": logprobs_or_none, + "finish_reason": finish_reason, + } + ], + "usage": usage, + } + else: + return { + "id": completion_id, + "object": "text_completion", + "created": created, + "model": model_name, + "choices": [ + { + "text": text, + "index": index, + "logprobs": logprobs_or_none, + "finish_reason": finish_reason, + } + ], + } + def _create_completion( self, prompt: Union[str, List[int]], @@ -1444,24 +1439,20 @@ def logit_bias_processor( "top_logprobs": [top_logprob], } returned_tokens += 1 - yield { - "id": completion_id, - "object": "text_completion", - "created": created, - "model": model_name, - "choices": [ - { - "text": self.detokenize( - [token], - prev_tokens=prompt_tokens - + completion_tokens[:returned_tokens], - ).decode("utf-8", errors="ignore"), - "index": 0, - "logprobs": logprobs_or_none, - "finish_reason": None, - } - ], - } + yield self._create_chunk( + completion_id=completion_id, + created=created, + model_name=model_name, + text=self.detokenize( + [token], + prev_tokens=prompt_tokens + + completion_tokens[:returned_tokens], + ).decode("utf-8", errors="ignore"), + logprobs_or_none=logprobs_or_none, + index=0, + finish_reason=None, + usage=None, + ) else: while len(remaining_tokens) > 0: decode_success = False @@ -1490,20 +1481,16 @@ def logit_bias_processor( remaining_tokens = remaining_tokens[i:] returned_tokens += i - yield { - "id": completion_id, - "object": "text_completion", - "created": created, - "model": model_name, - "choices": [ - { - "text": ts, - "index": 0, - "logprobs": None, - "finish_reason": None, - } - ], - } + yield self._create_chunk( + completion_id=completion_id, + created=created, + model_name=model_name, + text=ts, + logprobs_or_none=None, + index=0, + finish_reason=None, + usage=None, + ) if len(completion_tokens) >= max_tokens: text = self.detokenize(completion_tokens, 
prev_tokens=prompt_tokens) @@ -1582,54 +1569,51 @@ def logit_bias_processor( if token_end_position == end - 1: break returned_tokens += 1 - yield { - "id": completion_id, - "object": "text_completion", - "created": created, - "model": model_name, - "choices": [ - { - "text": last_text[ - : len(last_text) - (token_end_position - end) - ].decode("utf-8", errors="ignore"), - "index": 0, - "logprobs": logprobs_or_none, - "finish_reason": None, - } - ], - } + yield self._create_chunk( + completion_id=completion_id, + created=created, + model_name=model_name, + text=last_text[ + : len(last_text) - (token_end_position - end) + ].decode("utf-8", errors="ignore"), + logprobs_or_none=logprobs_or_none, + index=0, + finish_reason=None, + usage=None, + ) break returned_tokens += 1 - yield { - "id": completion_id, - "object": "text_completion", - "created": created, - "model": model_name, - "choices": [ - { - "text": self.detokenize([token]).decode( - "utf-8", errors="ignore" - ), - "index": 0, - "logprobs": logprobs_or_none, - "finish_reason": None, - } - ], - } - yield { - "id": completion_id, - "object": "text_completion", - "created": created, - "model": model_name, - "choices": [ - { - "text": "", - "index": 0, - "logprobs": None, - "finish_reason": finish_reason, - } - ], + yield self._create_chunk( + completion_id=completion_id, + created=created, + model_name=model_name, + text=self.detokenize([token]).decode( + "utf-8", errors="ignore" + ), + logprobs_or_none=logprobs_or_none, + index=0, + finish_reason=None, + usage=None, + ) + + # Final streaming chunk with both finish_reason and usage + usage = { + "prompt_tokens": len(prompt_tokens), + "completion_tokens": returned_tokens, + "total_tokens": len(prompt_tokens) + returned_tokens, } + + yield self._create_chunk( + completion_id=completion_id, + created=created, + model_name=model_name, + text="", + logprobs_or_none=None, + index=0, + finish_reason=finish_reason, + usage=usage, + ) + if self.cache: if self.verbose: print("Llama._create_completion: cache save", file=sys.stderr) @@ -2359,6 +2343,262 @@ def from_pretrained( **kwargs, ) + def _create_context( + self, + *, + n_ctx: int = 512, + n_batch: int = 512, + n_ubatch: int = 512, + n_threads: Optional[int] = None, + n_threads_batch: Optional[int] = None, + rope_scaling_type: Optional[ + int + ] = llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED, + pooling_type: int = llama_cpp.LLAMA_POOLING_TYPE_UNSPECIFIED, + rope_freq_base: float = 0.0, + rope_freq_scale: float = 0.0, + yarn_ext_factor: float = -1.0, + yarn_attn_factor: float = 1.0, + yarn_beta_fast: float = 32.0, + yarn_beta_slow: float = 1.0, + yarn_orig_ctx: int = 0, + logits_all: bool = False, + embedding: bool = False, + offload_kqv: bool = True, + flash_attn: bool = False, + # Sampling Params + no_perf: bool = False, + last_n_tokens_size: int = 64, + type_k: Optional[GGMLType] = None, + type_v: Optional[GGMLType] = None, + state: Optional[LlamaState] = None, + ) -> None: + """Free the existing context and create a new one with specified parameters. + + Args: + n_ctx: Text context size. If 0, value from model is used. + n_batch: Maximum batch size for llama_decode. + n_ubatch: Maximum physical batch size. + n_seq_max: Maximum number of sequences (distinct states for recurrent models). + n_threads: Number of threads to use for generation. + n_threads_batch: Number of threads to use for batch processing. + rope_scaling_type: RoPE scaling type from llama_rope_scaling_type enum. 
+ pooling_type: Whether to pool embedding results by sequence id. + attention_type: Attention type to use for embeddings. + rope_freq_base: RoPE base frequency, 0 = from model. + rope_freq_scale: RoPE frequency scaling factor, 0 = from model. + yarn_ext_factor: YaRN extrapolation mix factor, negative = from model. + yarn_attn_factor: YaRN magnitude scaling factor. + yarn_beta_fast: YaRN low correction dim. + yarn_beta_slow: YaRN high correction dim. + yarn_orig_ctx: YaRN original context size. + defrag_thold: Defragment KV cache if holes/size > thold, < 0 disabled. + type_k: Data type for K cache. + type_v: Data type for V cache. + logits_all: Compute all logits in llama_decode (deprecated). + embeddings: Extract embeddings with logits. + offload_kqv: Offload KQV ops (including KV cache) to GPU. + flash_attn: Use flash attention. + no_perf: Disable performance timings. + last_n_tokens_size: Size of the last n tokens. + type_k: Data type for K cache. + type_v: Data type for V cache. + """ + # Create new context params with provided values + self.n_batch = min(n_ctx, n_batch) # ??? + self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1) + self.n_threads_batch = n_threads_batch or multiprocessing.cpu_count() + + # Context Params + self.context_params = llama_cpp.llama_context_default_params() + self.context_params.n_ctx = n_ctx + self.context_params.n_batch = self.n_batch + self.context_params.n_ubatch = min(self.n_batch, n_ubatch) + self.context_params.n_threads = self.n_threads + self.context_params.n_threads_batch = self.n_threads_batch + self.context_params.rope_scaling_type = ( + rope_scaling_type + if rope_scaling_type is not None + else llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED + ) + self.context_params.pooling_type = pooling_type + self.context_params.rope_freq_base = ( + rope_freq_base if rope_freq_base != 0.0 else 0 + ) + self.context_params.rope_freq_scale = ( + rope_freq_scale if rope_freq_scale != 0.0 else 0 + ) + self.context_params.yarn_ext_factor = ( + yarn_ext_factor if yarn_ext_factor != 0.0 else 0 + ) + self.context_params.yarn_attn_factor = ( + yarn_attn_factor if yarn_attn_factor != 0.0 else 0 + ) + self.context_params.yarn_beta_fast = ( + yarn_beta_fast if yarn_beta_fast != 0.0 else 0 + ) + self.context_params.yarn_beta_slow = ( + yarn_beta_slow if yarn_beta_slow != 0.0 else 0 + ) + self.context_params.yarn_orig_ctx = yarn_orig_ctx if yarn_orig_ctx != 0 else 0 + self.context_params.logits_all = ( + logits_all if self.draft_model is None else True + ) # Must be set to True for speculative decoding + self.context_params.embeddings = embedding # TODO: Rename to embeddings + self.context_params.offload_kqv = offload_kqv + self.context_params.flash_attn = flash_attn + # KV cache quantization + if type_k is not None: + self.context_params.type_k = type_k + if type_v is not None: + self.context_params.type_v = type_v + + self.context_params.no_perf = no_perf + self.last_n_tokens_size = last_n_tokens_size + + # Set the default value for the context and correct the batch + if n_ctx == 0: + n_ctx = self._model.n_ctx_train() + self.n_batch = min(n_ctx, n_batch) + self.context_params.n_ctx = self._model.n_ctx_train() + self.context_params.n_batch = self.n_batch + self.context_params.n_ubatch = min(self.n_batch, n_ubatch) + + self._ctx = self._stack.enter_context( + contextlib.closing( + internals.LlamaContext( + model=self._model, + params=self.context_params, + verbose=self.verbose, + ) + ) + ) + + if state is not None: + self.load_state(state) + + self._n_ctx 
= self.n_ctx() + self.n_tokens = 0 + + self.input_ids: npt.NDArray[np.intc] = np.ndarray((n_ctx,), dtype=np.intc) + self.scores: npt.NDArray[np.single] = np.ndarray( + (n_ctx if logits_all == True else n_batch, self._n_vocab), dtype=np.single + ) + + self._batch = self._stack.enter_context( + contextlib.closing( + internals.LlamaBatch( + n_tokens=self.n_batch, + embd=0, + n_seq_max=self.context_params.n_ctx, + verbose=self.verbose, + ) + ) + ) + + + + if self._ctx is None: + raise RuntimeError("Failed to create new context") + + def recreate_context( + self, + *, + n_ctx: int = 512, + n_batch: int = 512, + n_ubatch: int = 512, + n_threads: Optional[int] = None, + n_threads_batch: Optional[int] = None, + rope_scaling_type: Optional[ + int + ] = llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED, + pooling_type: int = llama_cpp.LLAMA_POOLING_TYPE_UNSPECIFIED, + rope_freq_base: float = 0.0, + rope_freq_scale: float = 0.0, + yarn_ext_factor: float = -1.0, + yarn_attn_factor: float = 1.0, + yarn_beta_fast: float = 32.0, + yarn_beta_slow: float = 1.0, + yarn_orig_ctx: int = 0, + logits_all: bool = False, + embedding: bool = False, + offload_kqv: bool = True, + flash_attn: bool = False, + # Sampling Params + no_perf: bool = False, + last_n_tokens_size: int = 64, + type_k: Optional[GGMLType] = None, + type_v: Optional[GGMLType] = None, + ) -> None: + """Free the existing context and create a new one with specified parameters. + + Args: + n_ctx: Text context size. If 0, value from model is used. + n_batch: Maximum batch size for llama_decode. + n_ubatch: Maximum physical batch size. + n_seq_max: Maximum number of sequences (distinct states for recurrent models). + n_threads: Number of threads to use for generation. + n_threads_batch: Number of threads to use for batch processing. + rope_scaling_type: RoPE scaling type from llama_rope_scaling_type enum. + pooling_type: Whether to pool embedding results by sequence id. + attention_type: Attention type to use for embeddings. + rope_freq_base: RoPE base frequency, 0 = from model. + rope_freq_scale: RoPE frequency scaling factor, 0 = from model. + yarn_ext_factor: YaRN extrapolation mix factor, negative = from model. + yarn_attn_factor: YaRN magnitude scaling factor. + yarn_beta_fast: YaRN low correction dim. + yarn_beta_slow: YaRN high correction dim. + yarn_orig_ctx: YaRN original context size. + defrag_thold: Defragment KV cache if holes/size > thold, < 0 disabled. + type_k: Data type for K cache. + type_v: Data type for V cache. + logits_all: Compute all logits in llama_decode (deprecated). + embeddings: Extract embeddings with logits. + offload_kqv: Offload KQV ops (including KV cache) to GPU. + flash_attn: Use flash attention. + no_perf: Disable performance timings. + last_n_tokens_size: Size of the last n tokens. + type_k: Data type for K cache. + type_v: Data type for V cache. 
+ """ + + current_state = self.save_state() + + if self._ctx is not None: + self._ctx.close() + self._ctx = None + + # Free existing context if it exists + self._create_context( + n_ctx=n_ctx, + n_batch=n_batch, + n_ubatch=min(n_batch, n_ubatch), + n_threads=n_threads, + n_threads_batch=n_threads_batch, + rope_scaling_type=rope_scaling_type, + pooling_type=pooling_type, + rope_freq_base=rope_freq_base, + rope_freq_scale=rope_freq_scale, + yarn_ext_factor=yarn_ext_factor, + yarn_attn_factor=yarn_attn_factor, + yarn_beta_fast=yarn_beta_fast, + yarn_beta_slow=yarn_beta_slow, + yarn_orig_ctx=yarn_orig_ctx, + logits_all=logits_all, + embedding=embedding, + offload_kqv=offload_kqv, + flash_attn=flash_attn, + no_perf=no_perf, + last_n_tokens_size=last_n_tokens_size, + type_k=type_k, + type_v=type_v, + state=current_state, + ) + + # Reapply any LoRA adapter if it exists + if self._lora_adapter is not None: + llama_cpp.llama_set_adapter_lora(self._ctx, self._lora_adapter, self.lora_scale) + class LlamaState: def __init__( diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 17575c7..2abe67c 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -7,6 +7,7 @@ import dataclasses import random import string +import warnings from contextlib import ExitStack from typing import ( @@ -28,9 +29,7 @@ import numpy as np import numpy.typing as npt -import llama_cpp.llama as llama -import llama_cpp.llama_types as llama_types -import llama_cpp.llama_grammar as llama_grammar +from llama_cpp import llama, llama_grammar, llama_types from ._logger import logger from ._utils import suppress_stdout_stderr, Singleton @@ -348,6 +347,7 @@ def _convert_text_completion_chunks_to_chat( "finish_reason": chunk["choices"][0]["finish_reason"], } ], + "usage": chunk.get("usage") if "usage" in chunk else None, } @@ -432,7 +432,7 @@ def _stream_response_to_function_stream( created = chunk["created"] model = chunk["model"] tool_id = "call_" + "_0_" + tool_name + "_" + chunk["id"] - yield { + response = { "id": id_, "object": "chat.completion.chunk", "created": created, @@ -451,7 +451,11 @@ def _stream_response_to_function_stream( } ], } - yield { + if "usage" in chunk: + response["usage"] = chunk["usage"] + yield response + + response = { "id": "chat" + chunk["id"], "object": "chat.completion.chunk", "created": chunk["created"], @@ -485,10 +489,14 @@ def _stream_response_to_function_stream( } ], } + if "usage" in chunk: + response["usage"] = chunk["usage"] + yield response first = False continue + assert tool_id is not None - yield { + response = { "id": "chat" + chunk["id"], "object": "chat.completion.chunk", "created": chunk["created"], @@ -520,9 +528,12 @@ def _stream_response_to_function_stream( } ], } + if "usage" in chunk: + response["usage"] = chunk["usage"] + yield response if id_ is not None and created is not None and model is not None: - yield { + response = { "id": id_, "object": "chat.completion.chunk", "created": created, @@ -541,6 +552,9 @@ def _stream_response_to_function_stream( } ], } + if "usage" in chunk: + response["usage"] = chunk["usage"] + yield response return _stream_response_to_function_stream(chunks) @@ -2121,6 +2135,7 @@ def generate_streaming(tools, functions, function_call, prompt): }, } ], + usage=chunk["usage"] if "usage" in chunk else None, ) first = False if tools is not None: @@ -2161,6 +2176,7 @@ def generate_streaming(tools, functions, function_call, prompt): }, } ], + usage=chunk["usage"] if "usage" in chunk else None, ) # Yield 
tool_call/function_call stop message yield llama_types.CreateChatCompletionStreamResponse( @@ -2183,6 +2199,7 @@ def generate_streaming(tools, functions, function_call, prompt): }, } ], + usage=chunk["usage"] if "usage" in chunk else None, ) # If "auto" or no tool_choice/function_call elif isinstance(function_call, str) and function_call == "auto": @@ -2218,6 +2235,7 @@ def generate_streaming(tools, functions, function_call, prompt): "finish_reason": None, } ], + usage=chunk["usage"] if "usage" in chunk else None, ) else: prompt += f"{function_name}\n<|content|>" @@ -2263,6 +2281,7 @@ def generate_streaming(tools, functions, function_call, prompt): }, } ], + usage=chunk["usage"] if "usage" in chunk else None, ) # Generate content stops = [RECIPIENT_TOKEN, STOP_TOKEN] @@ -2300,6 +2319,7 @@ def generate_streaming(tools, functions, function_call, prompt): }, } ], + usage=chunk["usage"] if "usage" in chunk else None, ) is_end = False elif chunk["choices"][0]["text"] == "\n": @@ -2329,6 +2349,7 @@ def generate_streaming(tools, functions, function_call, prompt): }, } ], + usage=chunk["usage"] if "usage" in chunk else None, ) # Check whether the model wants to generate another turn if ( @@ -2361,6 +2382,7 @@ def generate_streaming(tools, functions, function_call, prompt): "finish_reason": "stop", } ], + usage=chunk["usage"] if "usage" in chunk else None, ) break else: @@ -2410,6 +2432,7 @@ def generate_streaming(tools, functions, function_call, prompt): }, } ], + usage=chunk["usage"] if "usage" in chunk else None, ) prompt += completion_text.strip() grammar = None @@ -2449,6 +2472,7 @@ def generate_streaming(tools, functions, function_call, prompt): }, } ], + usage=chunk["usage"] if "usage" in chunk else None, ) break @@ -2651,7 +2675,7 @@ def generate_streaming(tools, functions, function_call, prompt): class Llava15ChatHandler: DEFAULT_SYSTEM_MESSAGE: Optional[str] = ( - "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." + "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." 
) CHAT_FORMAT = ( @@ -2689,71 +2713,259 @@ class Llava15ChatHandler: "{% endif %}" ) - def __init__(self, clip_model_path: str, verbose: bool = True): - import llama_cpp.llava_cpp as llava_cpp - + def __init__(self, clip_model_path: str, llama_model: Optional[llama.Llama] = None, verbose: bool = True): self.clip_model_path = clip_model_path self.verbose = verbose - - self._llava_cpp = llava_cpp # TODO: Fix + self._mtmd_cpp = None self._exit_stack = ExitStack() - self._last_image_embed: Optional[ - llava_cpp.CtypesPointer[llava_cpp.llava_image_embed] - ] = None - self._last_image_hash: Optional[int] = None + self._bitmap_manager = None + self.clip_ctx = None + self._params = None if not os.path.exists(clip_model_path): raise ValueError(f"Clip model path does not exist: {clip_model_path}") - with suppress_stdout_stderr(disable=self.verbose): - clip_ctx = self._llava_cpp.clip_model_load(self.clip_model_path.encode(), 0) + # Initialize MTMD context if model is provided + if llama_model is not None: + self.initialize_mtmd_context(llama_model) - if clip_ctx is None: - raise ValueError(f"Failed to load clip model: {clip_model_path}") + def initialize_mtmd_context(self, llama_model: llama.Llama): + """Initialize the MTMD context with a llama model.""" + import llama_cpp.mtmd_cpp as mtmd_cpp + self._mtmd_cpp = mtmd_cpp - self.clip_ctx = clip_ctx + with suppress_stdout_stderr(disable=self.verbose): + params = self._mtmd_cpp.mtmd_context_params_default() + params.use_gpu = True # TODO: Make configurable + params.verbosity = 1 if self.verbose else 0 + self.clip_ctx = self._mtmd_cpp.mtmd_init_from_file( + self.clip_model_path.encode(), + llama_model._model.model, + params + ) - def clip_free(): - with suppress_stdout_stderr(disable=self.verbose): - self._llava_cpp.clip_free(self.clip_ctx) + if self.clip_ctx is None: + raise ValueError(f"Failed to load clip model: {self.clip_model_path}") - self._exit_stack.callback(clip_free) + def mtmd_free(): + with suppress_stdout_stderr(disable=self.verbose): + self._mtmd_cpp.mtmd_free(self.clip_ctx) - def last_image_embed_free(): - with suppress_stdout_stderr(disable=self.verbose): - if self._last_image_embed is not None: - self._llava_cpp.llava_image_embed_free(self._last_image_embed) - self._last_image_embed = None + self._exit_stack.callback(mtmd_free) - self._exit_stack.callback(last_image_embed_free) + def __call__(self, *args, **kwargs): + if self.clip_ctx is None: + # Initialize MTMD context with the llama model from the first argument + if len(args) > 0 and isinstance(args[0], llama.Llama): + self.initialize_mtmd_context(args[0]) + else: + raise ValueError("MTMD context not initialized. 
Please call initialize_mtmd_context with a llama model first.") + return super().__call__(*args, **kwargs) def load_image(self, image_url: str) -> bytes: return self._load_image(image_url) - def _embed_image_bytes(self, image_bytes: bytes, n_threads_batch: int = 1): - if ( - self._last_image_embed is not None - and self._last_image_hash is not None - and hash(image_bytes) == self._last_image_hash - ): - return self._last_image_embed - with suppress_stdout_stderr(disable=self.verbose): - # Free the previous image embed - if self._last_image_embed is not None: - self._llava_cpp.llava_image_embed_free(self._last_image_embed) - self._last_image_embed = None - self._last_image_hash = None - embed = self._llava_cpp.llava_image_embed_make_with_bytes( - self.clip_ctx, - n_threads_batch, - (ctypes.c_uint8 * len(image_bytes)).from_buffer( - bytearray(image_bytes) - ), - len(image_bytes), + def eval_image(self, llama: llama.Llama, image_url: str): + if self.clip_ctx is None: + self.initialize_mtmd_context(llama) + + image_bytes = self.load_image(image_url) + + # Create bitmap manager if not exists + if self._bitmap_manager is None: + self._bitmap_manager = self._mtmd_cpp.BitmapManager() + + # Create bitmap from bytes + if not self._bitmap_manager.add_from_memory(self.clip_ctx, image_bytes): + raise ValueError("Failed to create bitmap from image bytes") + + # Create input chunks for the bitmap + chunks = self._mtmd_cpp.mtmd_input_chunks_init() + if chunks is None: + raise ValueError("Failed to create input chunks") + + # Create input text with media marker + # Get media marker from context params + params = self._mtmd_cpp.mtmd_context_params_default() + text = self._mtmd_cpp.mtmd_input_text() + text.text = params.media_marker if params.media_marker else self._mtmd_cpp.mtmd_default_marker() + text.add_special = False + text.parse_special = True + + # Tokenize with bitmap + if self._mtmd_cpp.mtmd_tokenize(self.clip_ctx, chunks, text, self._bitmap_manager.c_ptr(), len(self._bitmap_manager.entries)) != 0: + self._mtmd_cpp.mtmd_input_chunks_free(chunks) + raise ValueError("Failed to tokenize image") + + # Get new n_past after evaluation + n_past = ctypes.c_int(llama.n_tokens) + n_past_p = ctypes.pointer(n_past) + + # Evaluate chunks + if self._mtmd_cpp.mtmd_helper_eval_chunks( + self.clip_ctx, + llama.ctx, + chunks, + llama.n_tokens, + 0, # seq_id + llama.n_batch, + True, # logits_last + n_past_p + ) != 0: + self._mtmd_cpp.mtmd_input_chunks_free(chunks) + raise ValueError("Failed to evaluate chunks") + + # Update n_tokens + llama.input_ids[llama.n_tokens : n_past.value] = -1 + llama.n_tokens = n_past.value + + # Cleanup + self._mtmd_cpp.mtmd_input_chunks_free(chunks) + self._bitmap_manager.clear() + + @staticmethod + def _load_image(image_url: str) -> bytes: + if image_url.startswith("data:"): + import base64 + image_bytes = base64.b64decode(image_url.split(",")[1]) + return image_bytes + elif image_url.startswith("http://") or image_url.startswith("https://"): + import urllib.request + with urllib.request.urlopen(image_url) as f: + image_bytes = f.read() + return image_bytes + else: + # Handle local file path + with open(image_url, 'rb') as f: + image_bytes = f.read() + return image_bytes + + @staticmethod + def get_image_urls(messages: List[llama_types.ChatCompletionRequestMessage]): + image_urls: List[str] = [] + for message in messages: + if message["role"] == "user": + if message["content"] is None: + continue + for content in message["content"]: + if isinstance(content, dict) and "type" in content: + 
if content["type"] == "image_url": + if ( + isinstance(content["image_url"], dict) + and "url" in content["image_url"] + ): + image_urls.append(content["image_url"]["url"]) + else: + image_urls.append(content["image_url"]) + return image_urls + + @staticmethod + def split_text_on_image_urls(text: str, image_urls: List[str]): + def find_first(s: str, substrs: List[str]): + for i, substr in enumerate(substrs): + pos = s.find(substr) + if pos != -1: + return pos, i + return None, None + + split_text: List[Tuple[Literal["text", "image_url"], str]] = [] + remaining = text + while remaining: + # Find first image_url + pos, i = find_first(remaining, image_urls) + if pos is not None and i is not None: + if pos > 0: + split_text.append(("text", remaining[:pos])) + split_text.append(("image_url", image_urls[i])) + remaining = remaining[pos + len(image_urls[i]) :] + else: + split_text.append(("text", remaining)) + remaining = "" + return split_text + + @classmethod + def from_pretrained( + cls, + repo_id: str, + filename: Optional[str], + local_dir: Optional[Union[str, os.PathLike[str]]] = None, + local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto", + cache_dir: Optional[Union[str, os.PathLike[str]]] = None, + **kwargs: Any, + ) -> "Llava15ChatHandler": + import fnmatch + from pathlib import Path + + try: + from huggingface_hub import hf_hub_download, HfFileSystem # type: ignore + from huggingface_hub.utils import validate_repo_id # type: ignore + except ImportError: + raise ImportError( + "Llama.from_pretrained requires the huggingface-hub package. " + "You can install it with `pip install huggingface-hub`." + ) + + validate_repo_id(repo_id) + + hffs = HfFileSystem() + + files = [ + file["name"] if isinstance(file, dict) else file + for file in hffs.ls(repo_id) # type: ignore + ] + + # split each file into repo_id, subfolder, filename + file_list: List[str] = [] + for file in files: + rel_path = Path(file).relative_to(repo_id) + file_list.append(str(rel_path)) + + matching_files = [file for file in file_list if fnmatch.fnmatch(file, filename)] # type: ignore + + if len(matching_files) == 0: + raise ValueError( + f"No file found in {repo_id} that match {filename}\n\n" + f"Available Files:\n{json.dumps(file_list)}" + ) + + if len(matching_files) > 1: + raise ValueError( + f"Multiple files found in {repo_id} matching {filename}\n\n" + f"Available Files:\n{json.dumps(files)}" + ) + + (matching_file,) = matching_files + + subfolder = str(Path(matching_file).parent) + filename = Path(matching_file).name + + # download the file + hf_hub_download( + repo_id=repo_id, + filename=filename, + subfolder=subfolder, + local_dir=cast(Union[str, Path, None], local_dir), + local_dir_use_symlinks=local_dir_use_symlinks, + cache_dir=cast(Union[str, Path, None], cache_dir), + ) + + if local_dir is None: + model_path = hf_hub_download( + repo_id=repo_id, + filename=filename, + subfolder=subfolder, + local_dir=local_dir, + local_dir_use_symlinks=local_dir_use_symlinks, + cache_dir=cast(Union[str, Path, None], cache_dir), + local_files_only=True, ) - self._last_image_embed = embed - self._last_image_hash = hash(image_bytes) - return embed + else: + model_path = os.path.join(local_dir, filename) + + return cls( + clip_model_path=model_path, + **kwargs, + ) def __call__( self, @@ -2820,10 +3032,9 @@ def __call__( if self.verbose: print(text, file=sys.stderr) - # Evaluate prompt llama.reset() - llama._ctx.kv_cache_clear() + llama._ctx.kv_self_clear() for type_, value in split_text: if type_ == "text": tokens = 
llama.tokenize( @@ -2835,24 +3046,7 @@ def __call__( ) llama.eval(tokens) else: - image_bytes = self.load_image(value) - embed = self._embed_image_bytes(image_bytes, llama.context_params.n_threads_batch) - if llama.n_tokens + embed.contents.n_image_pos > llama.n_ctx(): - raise ValueError( - f"Prompt exceeds n_ctx: {llama.n_tokens + embed.contents.n_image_pos} > {llama.n_ctx()}" - ) - n_past = ctypes.c_int(llama.n_tokens) - n_past_p = ctypes.pointer(n_past) - with suppress_stdout_stderr(disable=self.verbose): - self._llava_cpp.llava_eval_image_embed( - llama.ctx, - embed, - llama.n_batch, - n_past_p, - ) - # Required to avoid issues with hf tokenizer - llama.input_ids[llama.n_tokens : n_past.value] = -1 - llama.n_tokens = n_past.value + self.eval_image(llama, value) # Get prompt tokens to avoid a cache miss prompt = llama.input_ids[: llama.n_tokens].tolist() @@ -2938,203 +3132,61 @@ def __call__( ) return _convert_completion_to_chat(completion_or_chunks, stream=stream) - @staticmethod - def _load_image(image_url: str) -> bytes: - # TODO: Add Pillow support for other image formats beyond (jpg, png) - if image_url.startswith("data:"): - import base64 - image_bytes = base64.b64decode(image_url.split(",")[1]) - return image_bytes - else: - import urllib.request +class ObsidianChatHandler(Llava15ChatHandler): + # Prompt Format + # The model followed ChatML format. However, with ### as the seperator - with urllib.request.urlopen(image_url) as f: - image_bytes = f.read() - return image_bytes + # <|im_start|>user + # What is this sign about?\n + # ### + # <|im_start|>assistant + # The sign is about bullying, and it is placed on a black background with a red background. + # ### - @staticmethod - def get_image_urls(messages: List[llama_types.ChatCompletionRequestMessage]): - image_urls: List[str] = [] - for message in messages: - if message["role"] == "user": - if message["content"] is None: - continue - for content in message["content"]: - if isinstance(content, dict) and "type" in content: - if content["type"] == "image_url": - if ( - isinstance(content["image_url"], dict) - and "url" in content["image_url"] - ): - image_urls.append(content["image_url"]["url"]) - else: - image_urls.append(content["image_url"]) - return image_urls - - @staticmethod - def split_text_on_image_urls(text: str, image_urls: List[str]): - def find_first(s: str, substrs: List[str]): - for i, substr in enumerate(substrs): - pos = s.find(substr) - if pos != -1: - return pos, i - return None, None - - split_text: List[Tuple[Literal["text", "image_url"], str]] = [] - remaining = text - while remaining: - # Find first image_url - pos, i = find_first(remaining, image_urls) - if pos is not None and i is not None: - if pos > 0: - split_text.append(("text", remaining[:pos])) - split_text.append(("image_url", image_urls[i])) - remaining = remaining[pos + len(image_urls[i]) :] - else: - split_text.append(("text", remaining)) - remaining = "" - return split_text - - @classmethod - def from_pretrained( - cls, - repo_id: str, - filename: Optional[str], - local_dir: Optional[Union[str, os.PathLike[str]]] = None, - local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto", - cache_dir: Optional[Union[str, os.PathLike[str]]] = None, - **kwargs: Any, - ) -> "Llava15ChatHandler": - import fnmatch - from pathlib import Path - - try: - from huggingface_hub import hf_hub_download, HfFileSystem # type: ignore - from huggingface_hub.utils import validate_repo_id # type: ignore - except ImportError: - raise ImportError( - "Llama.from_pretrained 
requires the huggingface-hub package. " - "You can install it with `pip install huggingface-hub`." - ) - - validate_repo_id(repo_id) - - hffs = HfFileSystem() - - files = [ - file["name"] if isinstance(file, dict) else file - for file in hffs.ls(repo_id) # type: ignore - ] - - # split each file into repo_id, subfolder, filename - file_list: List[str] = [] - for file in files: - rel_path = Path(file).relative_to(repo_id) - file_list.append(str(rel_path)) - - matching_files = [file for file in file_list if fnmatch.fnmatch(file, filename)] # type: ignore - - if len(matching_files) == 0: - raise ValueError( - f"No file found in {repo_id} that match {filename}\n\n" - f"Available Files:\n{json.dumps(file_list)}" - ) - - if len(matching_files) > 1: - raise ValueError( - f"Multiple files found in {repo_id} matching {filename}\n\n" - f"Available Files:\n{json.dumps(files)}" - ) - - (matching_file,) = matching_files - - subfolder = str(Path(matching_file).parent) - filename = Path(matching_file).name - - # download the file - hf_hub_download( - repo_id=repo_id, - filename=filename, - subfolder=subfolder, - local_dir=cast(Union[str, Path, None], local_dir), - local_dir_use_symlinks=local_dir_use_symlinks, - cache_dir=cast(Union[str, Path, None], cache_dir), - ) - - if local_dir is None: - model_path = hf_hub_download( - repo_id=repo_id, - filename=filename, - subfolder=subfolder, - local_dir=local_dir, - local_dir_use_symlinks=local_dir_use_symlinks, - cache_dir=cast(Union[str, Path, None], cache_dir), - local_files_only=True, - ) - else: - model_path = os.path.join(local_dir, filename) - - return cls( - clip_model_path=model_path, - **kwargs, - ) - - -class ObsidianChatHandler(Llava15ChatHandler): - # Prompt Format - # The model followed ChatML format. However, with ### as the seperator - - # <|im_start|>user - # What is this sign about?\n - # ### - # <|im_start|>assistant - # The sign is about bullying, and it is placed on a black background with a red background. 
- # ### - - CHAT_FORMAT = ( - "{% for message in messages %}" - # System message - "{% if message.role == 'system' %}" - "<|im_start|>system\n" - "{{ message.content }}\n" - "###\n" - "{% endif %}" - # User message - "{% if message.role == 'user' %}" - "<|im_start|>user\n" - "{% if message.content is string %}" - "{{ message.content }}" - "{% endif %}" - "{% if message.content is iterable %}" - "{% for content in message.content %}" - "{% if content.type == 'image_url' and content.image_url is string %}" - "{{ content.image_url }}" - "{% endif %}" - "{% if content.type == 'image_url' and content.image_url is mapping %}" - "{{ content.image_url.url }}" - "{% endif %}" - "{% endfor %}" - "{% for content in message.content %}" - "{% if content.type == 'text' %}" - "{{ content.text }}" - "{% endif %}" - "{% endfor %}" - "{% endif %}" - "###\n" - "{% endif %}" - # Assistant message - "{% if message.role == 'assistant' %}" - "<|im_start|>assistant\n" - "{{ message.content }}" - "###\n" - "{% endif %}" - "{% endfor %}" - # Generation prompt - "{% if add_generation_prompt %}" - "<|im_start|>assistant\n" - "{% endif %}" - ) + CHAT_FORMAT = ( + "{% for message in messages %}" + # System message + "{% if message.role == 'system' %}" + "<|im_start|>system\n" + "{{ message.content }}\n" + "###\n" + "{% endif %}" + # User message + "{% if message.role == 'user' %}" + "<|im_start|>user\n" + "{% if message.content is string %}" + "{{ message.content }}" + "{% endif %}" + "{% if message.content is iterable %}" + "{% for content in message.content %}" + "{% if content.type == 'image_url' and content.image_url is string %}" + "{{ content.image_url }}" + "{% endif %}" + "{% if content.type == 'image_url' and content.image_url is mapping %}" + "{{ content.image_url.url }}" + "{% endif %}" + "{% endfor %}" + "{% for content in message.content %}" + "{% if content.type == 'text' %}" + "{{ content.text }}" + "{% endif %}" + "{% endfor %}" + "{% endif %}" + "###\n" + "{% endif %}" + # Assistant message + "{% if message.role == 'assistant' %}" + "<|im_start|>assistant\n" + "{{ message.content }}" + "###\n" + "{% endif %}" + "{% endfor %}" + # Generation prompt + "{% if add_generation_prompt %}" + "<|im_start|>assistant\n" + "{% endif %}" + ) class MoondreamChatHandler(Llava15ChatHandler): @@ -3373,6 +3425,229 @@ class MiniCPMv26ChatHandler(Llava15ChatHandler): ) +class Gemma3ChatHandler(Llava15ChatHandler): + # Chat Format: + # 'user\n{system_prompt}\n\n{prompt}\nmodel\n' + + DEFAULT_SYSTEM_MESSAGE = None + + CHAT_FORMAT = ( + "{{ '' }}" + "{%- if messages[0]['role'] == 'system' -%}" + "{%- if messages[0]['content'] is string -%}" + "{%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}" + "{%- else -%}" + "{%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}" + "{%- endif -%}" + "{%- set loop_messages = messages[1:] -%}" + "{%- else -%}" + "{%- set first_user_prefix = \"\" -%}" + "{%- set loop_messages = messages -%}" + "{%- endif -%}" + "{%- for message in loop_messages -%}" + "{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}" + "{{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}" + "{%- endif -%}" + "{%- if (message['role'] == 'assistant') -%}" + "{%- set role = \"model\" -%}" + "{%- else -%}" + "{%- set role = message['role'] -%}" + "{%- endif -%}" + "{{ '' + role + '\n' + (first_user_prefix if loop.first else \"\") }}" + "{%- if message['content'] is string -%}" + "{{ message['content'] | trim }}" + "{%- elif 
message['content'] is iterable -%}" + "{%- for item in message['content'] -%}" + "{%- if item['type'] == 'image_url' -%}" + "{{ '' }}" + "{%- elif item['type'] == 'text' -%}" + "{{ item['text'] | trim }}" + "{%- endif -%}" + "{%- endfor -%}" + "{%- else -%}" + "{{ raise_exception(\"Invalid content type\") }}" + "{%- endif -%}" + "{{ '\n' }}" + "{%- endfor -%}" + "{%- if add_generation_prompt -%}" + "{{ 'model\n' }}" + "{%- endif -%}" + ) + + @staticmethod + def split_text_on_image_urls(text: str, image_urls: List[str]): + split_text: List[Tuple[Literal["text", "image_url"], str]] = [] + copied_urls = image_urls[:] + remaining = text + image_placeholder = "" + + while remaining: + # Find placeholder + pos = remaining.find(image_placeholder) + if pos != -1: + assert len(copied_urls) > 0 + if pos > 0: + split_text.append(("text", remaining[:pos])) + split_text.append(("text", "\n\n")) + split_text.append(("image_url", copied_urls.pop(0))) + split_text.append(("text", "\n\n")) + remaining = remaining[pos + len(image_placeholder):] + else: + assert len(copied_urls) == 0 + split_text.append(("text", remaining)) + remaining = "" + return split_text + + +def _accumulate_chunks( + chunks_iterator: Iterator[llama_types.CreateCompletionStreamResponse], + chunks_list: List[llama_types.CreateCompletionStreamResponse], +) -> Iterator[llama_types.CreateCompletionStreamResponse]: + for chunk in chunks_iterator: + chunks_list.append(chunk) + yield chunk + + +def _convert_chunks_to_completion( + chunks: List[llama_types.CreateCompletionStreamResponse], +) -> llama_types.CreateCompletionResponse: + """Convert a list of completion chunks to a completion.""" + # Accumulate completion response values + text: str = "" + finish_reason: Optional[str] = None + logprobs: Optional[llama_types.CompletionLogprobs] = None + prompt_tokens = 0 + completion_tokens = 0 + total_tokens = 0 + completion_id: Optional[str] = None + completion_model: Optional[str] = None + completion_created: Optional[int] = None + for chunk in chunks: + # Extract the id, model, and created values from the first chunk + if completion_id is None: + completion_id = chunk["id"] + completion_model = chunk["model"] + completion_created = chunk["created"] + # Extract the usage if present in the chunk + usage = chunk.get("usage") + if usage: + prompt_tokens += usage.get("prompt_tokens", 0) + completion_tokens += usage.get("completion_tokens", 0) + total_tokens += usage.get("total_tokens", 0) + # Accumulate the chunk text + choice = chunk["choices"][0] + text += choice.get("text", "") + # Extract the finish_reason and logprobs if present in the chunk + if choice.get("finish_reason"): + finish_reason = choice["finish_reason"] + if choice.get("logprobs"): + logprobs = choice["logprobs"] + # Create the completion response + completion: llama_types.CreateCompletionResponse = { + "id": completion_id or "unknown_id", + "object": "text_completion", + "created": completion_created or 0, + "model": completion_model or "unknown_model", + "choices": [ + { + "text": text, + "index": 0, + "logprobs": logprobs, # TODO: Improve accumulation of logprobs + "finish_reason": finish_reason, # type: ignore[typeddict-item] + } + ], + } + # Add usage section if present in the chunks + if (prompt_tokens + completion_tokens + total_tokens) > 0: + completion["usage"] = { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, + } + return completion + + +def _stream_tool_calls( + llama: llama.Llama, + prompt: str, + tools: 
List[llama_types.ChatCompletionTool], + tool_name: str, + completion_kwargs: dict[str, Any], + follow_up_gbnf_tool_grammar: str, +) -> Iterator[llama_types.CreateChatCompletionStreamResponse]: + # Generate a tool call completions + tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None) + completions: List[llama_types.CreateCompletionResponse] = [] + completions_tool_name: List[str] = [] + finish_reason_chat_chunk = None + while tool is not None and len(completions) <= 16: + # Generate the parameter values for the selected tool + prompt += f"functions.{tool_name}:\n" + try: + grammar = llama_grammar.LlamaGrammar.from_json_schema( + json.dumps(tool["function"]["parameters"]), verbose=llama.verbose + ) + except Exception as e: + warnings.warn( + f"Failed to parse function body as JSON schema, falling back to default grammar\n\n{e}", + category=RuntimeWarning, + stacklevel=2, + ) + grammar = llama_grammar.LlamaGrammar.from_string( + llama_grammar.JSON_GBNF, verbose=llama.verbose + ) + completion_or_chunks = llama.create_completion( + prompt=prompt, + **{ + **completion_kwargs, + "max_tokens": None, + "grammar": grammar, + }, + ) + chunks: List[llama_types.CreateCompletionResponse] = [] + chat_chunks = _convert_completion_to_chat_function( + tool_name, + _accumulate_chunks(completion_or_chunks, chunks), # type: ignore[arg-type] + stream=True, + ) + for chat_chunk in chat_chunks: + # Don't return the finish_reason chunk + if chat_chunk["choices"] and chat_chunk["choices"][0].get("finish_reason"): + finish_reason_chat_chunk = chat_chunk + break + # Update this tool call's index + if chat_chunk["choices"] and chat_chunk["choices"][0]["delta"].get("tool_calls"): + chat_chunk["choices"][0]["delta"]["tool_calls"][0]["index"] = len(completions) + yield chat_chunk + completion = _convert_chunks_to_completion(chunks) + completions.append(completion) + completions_tool_name.append(tool_name) + prompt += completion["choices"][0]["text"] + prompt += "\n" + # Determine whether to call another tool or stop + response = cast( + llama_types.CreateCompletionResponse, + llama.create_completion( + prompt=prompt, + **{ + **completion_kwargs, + "temperature": 0, + "stream": False, + "stop": [*completion_kwargs["stop"], ":", ""], + "max_tokens": None, + "grammar": llama_grammar.LlamaGrammar.from_string( + follow_up_gbnf_tool_grammar, verbose=llama.verbose + ), + }, + ), + ) + tool_name = response["choices"][0]["text"][len("functions.") :] + tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None) + # Yield the finish_reason chunk + if finish_reason_chat_chunk is not None: + yield finish_reason_chat_chunk + + @register_chat_completion_handler("chatml-function-calling") def chatml_function_calling( llama: llama.Llama, @@ -3402,7 +3677,7 @@ def chatml_function_calling( grammar: Optional[llama.LlamaGrammar] = None, logprobs: Optional[bool] = None, top_logprobs: Optional[int] = None, - **kwargs, # type: ignore + **kwargs: Any, ) -> Union[ llama_types.CreateChatCompletionResponse, Iterator[llama_types.CreateChatCompletionStreamResponse], @@ -3416,18 +3691,21 @@ def chatml_function_calling( "{% if tool_calls %}" "\n\nYou have access to the following functions:\n" "{% for tool in tools %}" + '\n{% if tool.function.get("description") %}/* {{ tool.function.description | trim }} */{% endif %}' "\nfunctions.{{ tool.function.name }}:\n" "{{ tool.function.parameters | tojson }}" "\n{% endfor %}" - "\n\nYou can respond to users messages with either a single message or one 
or more function calls." - "\n\nTo respond with a message begin the message with 'message:', use the following format:" + "\nYou must respond to user messages with either a single message or with one or more function calls." + "\n\nTo respond with a message use the following format:" "\n\nmessage:" "\n" - "\n\nTo respond with one or more function calls begin the message with 'functions.:', use the following format:" - "\n\nfunctions.:" + "\n\nTo respond with one or more function calls use the following format:" + "\n\n" + "\nfunctions.:" '\n{ "arg1": "value1", "arg2": "value2" }' "\nfunctions.:" '\n{ "arg1": "value1", "arg2": "value2" }' + "\n" "{% endif %}" "<|im_end|>\n" "{% endif %}" @@ -3438,7 +3716,7 @@ def chatml_function_calling( "{% endif %}" # Assistant message "{% if message.role == 'assistant' %}" - ## Reglar message + ## Regular message "{% if message.content and message.content | length > 0 %}" "{% if tool_calls %}" "message:\n" @@ -3465,35 +3743,55 @@ def chatml_function_calling( # Convert legacy functions to tools if functions is not None: - tools = [ - { - "type": "function", - "function": function, - } - for function in functions - ] + tools = [{"type": "function", "function": function} for function in functions] # Convert legacy function_call to tool_choice if function_call is not None: - if isinstance(function_call, str) and ( - function_call == "none" or function_call == "auto" - ): + if isinstance(function_call, str) and (function_call in ("none", "auto")): tool_choice = function_call if isinstance(function_call, dict) and "name" in function_call: - tool_choice = { - "type": "function", - "function": { - "name": function_call["name"], - }, - } + tool_choice = {"type": "function", "function": {"name": function_call["name"]}} + # Collect the llama.create_completion keyword arguments so we don't have to repeat these with + # each completion call stop = ( [stop, "<|im_end|>"] if isinstance(stop, str) - else stop + ["<|im_end|>"] if stop else ["<|im_end|>"] + else [*stop, "<|im_end|>"] + if stop + else ["<|im_end|>"] + ) + grammar = ( # It is assumed the grammar applies to messages only, not tool calls + grammar + if grammar is not None + else ( + _grammar_for_response_format(response_format) + if response_format is not None and response_format["type"] == "json_object" + else None + ) ) + completion_kwargs = { + "temperature": temperature, + "top_p": top_p, + "top_k": top_k, + "min_p": min_p, + "typical_p": typical_p, + "stream": stream, + "stop": stop, + "max_tokens": max_tokens, + "presence_penalty": presence_penalty, + "frequency_penalty": frequency_penalty, + "repeat_penalty": repeat_penalty, + "tfs_z": tfs_z, + "mirostat_mode": mirostat_mode, + "mirostat_tau": mirostat_tau, + "mirostat_eta": mirostat_eta, + "model": model, + "logits_processor": logits_processor, + "grammar": grammar, + } - # Case 1: No tool choice by user + # Case 1: No tool use if ( tool_choice is None or (isinstance(tool_choice, str) and tool_choice == "none") @@ -3501,316 +3799,526 @@ def chatml_function_calling( or len(tools) == 0 ): prompt = template_renderer.render( - messages=messages, - tools=[], - tool_calls=None, - add_generation_prompt=True, + messages=messages, tools=[], tool_calls=None, add_generation_prompt=True ) - - if response_format is not None and response_format["type"] == "json_object": - grammar = _grammar_for_response_format(response_format) - return _convert_completion_to_chat( llama.create_completion( prompt=prompt, - temperature=temperature, - top_p=top_p, - top_k=top_k, - 
min_p=min_p, - typical_p=typical_p, - stream=stream, - stop=stop, - max_tokens=max_tokens, - presence_penalty=presence_penalty, - frequency_penalty=frequency_penalty, - repeat_penalty=repeat_penalty, - tfs_z=tfs_z, - mirostat_mode=mirostat_mode, - mirostat_tau=mirostat_tau, - mirostat_eta=mirostat_eta, - model=model, - logits_processor=logits_processor, - grammar=grammar, + **completion_kwargs, # type: ignore[arg-type] logprobs=top_logprobs if logprobs else None, ), stream=stream, ) - # Case 2: Tool choice by user + # Ensure there is a system prompt to attach the tool metadata to + if not any(message["role"] == "system" for message in messages): + messages = [*messages, {"role": "system", "content": ""}] + + # Case 2: Automatic or fixed tool choice + # Case 2 step 1: Determine whether to respond with a message or a tool call + assert (isinstance(tool_choice, str) and tool_choice == "auto") or isinstance(tool_choice, dict) if isinstance(tool_choice, dict): - tool_name = tool_choice["function"]["name"] - tool = next( - (tool for tool in tools if tool["function"]["name"] == tool_name), None + tools = [t for t in tools if t["function"]["name"] == tool_choice["function"]["name"]] + assert tools + function_names = " | ".join([f'''"functions.{t['function']['name']}:"''' for t in tools]) + prompt = template_renderer.render( + messages=messages, tools=tools, tool_calls=True, add_generation_prompt=True + ) + initial_gbnf_tool_grammar = ( + ( + 'root ::= "" "\\n" functions | "message:"\n' + f"functions ::= {function_names}\n" ) - if tool is None: - raise ValueError(f"Tool with name '{tool_name}' not found in tools") + if tool_choice == "auto" + else f'root ::= "" "\\n" functions\nfunctions ::= {function_names}\n' + ) + completion = cast( + llama_types.CreateCompletionResponse, + llama.create_completion( + prompt=prompt, + **{ # type: ignore[arg-type] + **completion_kwargs, + "temperature": 0, + "stream": False, + "stop": [":"], + "max_tokens": None, + "grammar": llama_grammar.LlamaGrammar.from_string( + initial_gbnf_tool_grammar, verbose=llama.verbose + ), + }, + ), + ) + text = completion["choices"][0]["text"] + tool_name = None if text.startswith("message") else text.split("\n")[-1][len("functions.") :] + + # Case 2 step 2A: Respond with a message + if tool_name is None: prompt = template_renderer.render( - messages=messages, - tools=tools, - tool_calls=True, - add_generation_prompt=True, + messages=messages, tools=[], tool_calls=None, add_generation_prompt=True + ) + return _convert_completion_to_chat( + llama.create_completion( + prompt=prompt, + **completion_kwargs, # type: ignore[arg-type] + logprobs=top_logprobs if logprobs else None, + ), + stream=stream, + ) + + # Case 2 step 2B: One or more function calls + follow_up_gbnf_tool_grammar = ( + 'root ::= functions | "" | "<|im_end|>"\n' + f"functions ::= {function_names}\n" + ) + prompt += "\n" + if stream: + return _stream_tool_calls( + llama, prompt, tools, tool_name, completion_kwargs, follow_up_gbnf_tool_grammar ) + tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None) + completions: List[llama_types.CreateCompletionResponse] = [] + completions_tool_name: List[str] = [] + while tool is not None and len(completions) <= 16: + # Generate the parameter values for the selected tool prompt += f"functions.{tool_name}:\n" try: grammar = llama_grammar.LlamaGrammar.from_json_schema( json.dumps(tool["function"]["parameters"]), verbose=llama.verbose ) except Exception as e: + warnings.warn( + f"Failed to parse function 
body as JSON schema, falling back to default grammar\n\n{e}", + category=RuntimeWarning, + stacklevel=2, + ) grammar = llama_grammar.LlamaGrammar.from_string( llama_grammar.JSON_GBNF, verbose=llama.verbose ) - if llama.verbose: - print( - "Failed to parse function body as JSON schema, falling back to default grammar" - ) - print(e) completion_or_chunks = llama.create_completion( prompt=prompt, - temperature=temperature, - top_p=top_p, - top_k=top_k, - min_p=min_p, - typical_p=typical_p, - stream=stream, - stop=stop, - max_tokens=max_tokens, - presence_penalty=presence_penalty, - frequency_penalty=frequency_penalty, - repeat_penalty=repeat_penalty, - tfs_z=tfs_z, - mirostat_mode=mirostat_mode, - mirostat_tau=mirostat_tau, - mirostat_eta=mirostat_eta, - model=model, - logits_processor=logits_processor, - grammar=grammar, + **{ # type: ignore[arg-type] + **completion_kwargs, + "max_tokens": None, + "grammar": grammar, + }, ) - return _convert_completion_to_chat_function( - tool_name, completion_or_chunks, stream + completion = cast(llama_types.CreateCompletionResponse, completion_or_chunks) + completions.append(completion) + completions_tool_name.append(tool_name) + prompt += completion["choices"][0]["text"] + prompt += "\n" + # Determine whether to call another tool or stop + response = cast( + llama_types.CreateCompletionResponse, + llama.create_completion( + prompt=prompt, + **{ # type: ignore[arg-type] + **completion_kwargs, + "temperature": 0, + "stream": False, + "stop": [*completion_kwargs["stop"], ":", ""], # type: ignore[misc] + "max_tokens": None, + "grammar": llama_grammar.LlamaGrammar.from_string( + follow_up_gbnf_tool_grammar, verbose=llama.verbose + ), + }, + ), ) + tool_name = response["choices"][0]["text"][len("functions.") :] + tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None) + # Merge the completions into a single chat completion + chat_completion: llama_types.CreateChatCompletionResponse = { + "id": "chat" + completion["id"], + "object": "chat.completion", + "created": completion["created"], + "model": completion["model"], + "choices": [ + { + "finish_reason": "tool_calls", + "index": 0, + "logprobs": _convert_text_completion_logprobs_to_chat( + completion["choices"][0]["logprobs"] + ), + "message": { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_" + f"_{i}_" + tool_name + "_" + completion["id"], + "type": "function", + "function": { + "name": tool_name, + "arguments": completion["choices"][0]["text"], + }, + } + for i, (tool_name, completion) in enumerate( + zip(completions_tool_name, completions) + ) + ], + }, + } + ], + "usage": { + "completion_tokens": sum( + (completion["usage"]["completion_tokens"] if "usage" in completion else 0) + for completion in completions + ), + "prompt_tokens": sum( + completion["usage"]["prompt_tokens"] if "usage" in completion else 0 + for completion in completions + ), + "total_tokens": sum( + completion["usage"]["total_tokens"] if "usage" in completion else 0 + for completion in completions + ), + }, + } + if len(completions) == 1: + single_function_call: llama_types.ChatCompletionResponseFunctionCall = { + "name": tool_name, + "arguments": completions[0]["choices"][0]["text"], + } + chat_completion["choices"][0]["message"]["function_call"] = single_function_call + return chat_completion + - # Case 3: Automatic tool choice - assert isinstance(tool_choice, str) and tool_choice == "auto" - function_names = " | ".join( - [f'''"functions.{tool['function']['name']}:"''' for 
tool in tools] +@register_chat_completion_handler("gguf-function-calling") +def gguf_function_calling( + llama: llama.Llama, + messages: List[llama_types.ChatCompletionRequestMessage], + functions: Optional[List[llama_types.ChatCompletionFunction]] = None, + function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None, + tools: Optional[List[llama_types.ChatCompletionTool]] = None, + tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None, + temperature: float = 0.2, + top_p: float = 0.95, + top_k: int = 40, + min_p: float = 0.05, + typical_p: float = 1.0, + stream: bool = False, + stop: Optional[Union[str, List[str]]] = [], + response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None, + max_tokens: Optional[int] = None, + presence_penalty: float = 0.0, + frequency_penalty: float = 0.0, + repeat_penalty: float = 1.1, + tfs_z: float = 1.0, + mirostat_mode: int = 0, + mirostat_tau: float = 5.0, + mirostat_eta: float = 0.1, + model: Optional[str] = None, + logits_processor: Optional[llama.LogitsProcessorList] = None, + grammar: Optional[llama.LlamaGrammar] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, + **kwargs: Any, +) -> Union[ + llama_types.CreateChatCompletionResponse, + Iterator[llama_types.CreateChatCompletionStreamResponse], +]: + + function_calling_template = None + if hasattr(llama, 'model_path'): + metadata = llama.metadata + if metadata and "tokenizer.chat_template" in metadata: + function_calling_template = metadata["tokenizer.chat_template"] + + + function_calling_template = ( + "{% for message in messages %}" + "<|im_start|>{{ message.role }}\n" + # System message + "{% if message.role == 'system' %}" + "{{ message.content }}" + "{% if tool_calls %}" + "\n\nYou have access to the following functions:\n" + "{% for tool in tools %}" + '\n{% if tool.function.get("description") %}/* {{ tool.function.description | trim }} */{% endif %}' + "\nfunctions.{{ tool.function.name }}:\n" + "{{ tool.function.parameters | tojson }}" + "\n{% endfor %}" + "\nYou must respond to user messages with either a single message or with one or more function calls." 
+ "\n\nTo respond with a message use the following format:" + "\n\nmessage:" + "\n" + "\n\nTo respond with one or more function calls use the following format:" + "\n\n" + "\nfunctions.:" + '\n{ "arg1": "value1", "arg2": "value2" }' + "\nfunctions.:" + '\n{ "arg1": "value1", "arg2": "value2" }' + "\n" + "{% endif %}" + "<|im_end|>\n" + "{% endif %}" + # User message + "{% if message.role == 'user' %}" + "{{ message.content }}" + "<|im_end|>\n" + "{% endif %}" + # Assistant message + "{% if message.role == 'assistant' %}" + ## Regular message + "{% if message.content and message.content | length > 0 %}" + "{% if tool_calls %}" + "message:\n" + "{% endif %}" + "{{ message.content }}" + "<|im_end|>\n" + "{% endif %}" + ## Function calls + "{% if 'tool_calls' in message %}" + "{% for tool_call in message.tool_calls %}" + "functions.{{ tool_call.function.name }}:\n" + "{{ tool_call.function.arguments }}" + "{% endfor %}" + "<|im_end|>\n" + "{% endif %}" + "{% endif %}" + "{% endfor %}" + "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}" ) - initial_gbnf_tool_grammar = ( - """root ::= functions | "message:"\n""" - f"""functions ::= {function_names}\n""" + template_renderer = ImmutableSandboxedEnvironment( + autoescape=jinja2.select_autoescape(["html", "xml"]), + undefined=jinja2.StrictUndefined, + ).from_string(function_calling_template) + + # Convert legacy functions to tools + if functions is not None: + tools = [{"type": "function", "function": function} for function in functions] + + # Convert legacy function_call to tool_choice + if function_call is not None: + if isinstance(function_call, str) and (function_call in ("none", "auto")): + tool_choice = function_call + if isinstance(function_call, dict) and "name" in function_call: + tool_choice = {"type": "function", "function": {"name": function_call["name"]}} + + # Collect the llama.create_completion keyword arguments so we don't have to repeat these with + # each completion call + stop = ( + [stop, "<|im_end|>"] + if isinstance(stop, str) + else [*stop, "<|im_end|>"] + if stop + else ["<|im_end|>"] ) - follow_up_gbnf_tool_grammar = ( - """root ::= functions | "<|im_end|>"\n""" - f"""functions ::= {function_names}\n""" + grammar = ( # It is assumed the grammar applies to messages only, not tool calls + grammar + if grammar is not None + else ( + _grammar_for_response_format(response_format) + if response_format is not None and response_format["type"] == "json_object" + else None + ) ) + completion_kwargs = { + "temperature": temperature, + "top_p": top_p, + "top_k": top_k, + "min_p": min_p, + "typical_p": typical_p, + "stream": stream, + "stop": stop, + "max_tokens": max_tokens, + "presence_penalty": presence_penalty, + "frequency_penalty": frequency_penalty, + "repeat_penalty": repeat_penalty, + "tfs_z": tfs_z, + "mirostat_mode": mirostat_mode, + "mirostat_tau": mirostat_tau, + "mirostat_eta": mirostat_eta, + "model": model, + "logits_processor": logits_processor, + "grammar": grammar, + } + + # Case 1: No tool use + if ( + tool_choice is None + or (isinstance(tool_choice, str) and tool_choice == "none") + or tools is None + or len(tools) == 0 + ): + prompt = template_renderer.render( + messages=messages, tools=[], tool_calls=None, add_generation_prompt=True + ) + return _convert_completion_to_chat( + llama.create_completion( + prompt=prompt, + **completion_kwargs, # type: ignore[arg-type] + logprobs=top_logprobs if logprobs else None, + ), + stream=stream, + ) + + # Ensure there is a system prompt to attach the tool 
metadata to + if not any(message["role"] == "system" for message in messages): + messages = [*messages, {"role": "system", "content": ""}] + + # Case 2: Automatic or fixed tool choice + # Case 2 step 1: Determine whether to respond with a message or a tool call + assert (isinstance(tool_choice, str) and tool_choice == "auto") or isinstance(tool_choice, dict) + if isinstance(tool_choice, dict): + tools = [t for t in tools if t["function"]["name"] == tool_choice["function"]["name"]] + assert tools + function_names = " | ".join([f'''"functions.{t['function']['name']}:"''' for t in tools]) prompt = template_renderer.render( - messages=messages, - tools=tools, - tool_calls=True, - add_generation_prompt=True, + messages=messages, tools=tools, tool_calls=True, add_generation_prompt=True ) - completion_or_chunks = llama.create_completion( - prompt=prompt, - temperature=0, - top_p=top_p, - top_k=top_k, - min_p=min_p, - typical_p=typical_p, - stream=False, - stop=[":"], - max_tokens=None, - presence_penalty=presence_penalty, - frequency_penalty=frequency_penalty, - repeat_penalty=repeat_penalty, - tfs_z=tfs_z, - mirostat_mode=mirostat_mode, - mirostat_tau=mirostat_tau, - mirostat_eta=mirostat_eta, - model=model, - logits_processor=logits_processor, - grammar=llama_grammar.LlamaGrammar.from_string( - initial_gbnf_tool_grammar, verbose=llama.verbose + initial_gbnf_tool_grammar = ( + ( + 'root ::= "" "\\n" functions | "message:"\n' + f"functions ::= {function_names}\n" + ) + if tool_choice == "auto" + else f'root ::= "" "\\n" functions\nfunctions ::= {function_names}\n' + ) + completion = cast( + llama_types.CreateCompletionResponse, + llama.create_completion( + prompt=prompt, + **{ # type: ignore[arg-type] + **completion_kwargs, + "temperature": 0, + "stream": False, + "stop": [":"], + "max_tokens": None, + "grammar": llama_grammar.LlamaGrammar.from_string( + initial_gbnf_tool_grammar, verbose=llama.verbose + ), + }, ), ) - completion: llama_types.CreateCompletionResponse = completion_or_chunks # type: ignore text = completion["choices"][0]["text"] - if "message" in text: + tool_name = None if text.startswith("message") else text.split("\n")[-1][len("functions.") :] + + # Case 2 step 2A: Respond with a message + if tool_name is None: + prompt = template_renderer.render( + messages=messages, tools=[], tool_calls=None, add_generation_prompt=True + ) return _convert_completion_to_chat( llama.create_completion( - prompt=prompt + "message:\n", - temperature=temperature, - top_p=top_p, - top_k=top_k, - min_p=min_p, - typical_p=typical_p, - stream=stream, - stop=["<|im_end|>"], + prompt=prompt, + **completion_kwargs, # type: ignore[arg-type] logprobs=top_logprobs if logprobs else None, - max_tokens=None, - presence_penalty=presence_penalty, - frequency_penalty=frequency_penalty, - repeat_penalty=repeat_penalty, - tfs_z=tfs_z, - mirostat_mode=mirostat_mode, - mirostat_tau=mirostat_tau, - mirostat_eta=mirostat_eta, - model=model, - logits_processor=logits_processor, - grammar=llama_grammar.LlamaGrammar.from_string( - follow_up_gbnf_tool_grammar, verbose=llama.verbose - ), ), stream=stream, ) - # One or more function calls - tool_name = text[len("functions.") :] + # Case 2 step 2B: One or more function calls + follow_up_gbnf_tool_grammar = ( + 'root ::= functions | "" | "<|im_end|>"\n' + f"functions ::= {function_names}\n" + ) + prompt += "\n" + if stream: + return _stream_tool_calls( + llama, prompt, tools, tool_name, completion_kwargs, follow_up_gbnf_tool_grammar + ) tool = next((tool for tool in tools if 
tool["function"]["name"] == tool_name), None) - if not stream: - completions: List[llama_types.CreateCompletionResponse] = [] - completions_tool_name: List[str] = [] - while tool is not None: - prompt += f"functions.{tool_name}:\n" - try: - grammar = llama_grammar.LlamaGrammar.from_json_schema( - json.dumps(tool["function"]["parameters"]), verbose=llama.verbose - ) - except Exception as e: - grammar = llama_grammar.LlamaGrammar.from_string( - llama_grammar.JSON_GBNF, verbose=llama.verbose - ) - if llama.verbose: - print( - "Failed to parse function body as JSON schema, falling back to default grammar" - ) - print(e) - completion_or_chunks = llama.create_completion( - prompt=prompt, - temperature=temperature, - top_p=top_p, - top_k=top_k, - min_p=min_p, - typical_p=typical_p, - stream=False, - stop=stop, - max_tokens=None, - presence_penalty=presence_penalty, - frequency_penalty=frequency_penalty, - repeat_penalty=repeat_penalty, - tfs_z=tfs_z, - mirostat_mode=mirostat_mode, - mirostat_tau=mirostat_tau, - mirostat_eta=mirostat_eta, - model=model, - logits_processor=logits_processor, - grammar=grammar, - ) - completion_or_chunks = cast( - llama_types.CreateCompletionResponse, completion_or_chunks + completions: List[llama_types.CreateCompletionResponse] = [] + completions_tool_name: List[str] = [] + while tool is not None and len(completions) <= 16: + # Generate the parameter values for the selected tool + prompt += f"functions.{tool_name}:\n" + try: + grammar = llama_grammar.LlamaGrammar.from_json_schema( + json.dumps(tool["function"]["parameters"]), verbose=llama.verbose ) - completions.append(completion_or_chunks) - completions_tool_name.append(tool_name) - prompt += completion_or_chunks["choices"][0]["text"] - prompt += "\n" - - response = llama.create_completion( - prompt=prompt, - temperature=temperature, - top_p=top_p, - top_k=top_k, - min_p=min_p, - typical_p=typical_p, - stream=False, - stop=stop, - max_tokens=None, - presence_penalty=presence_penalty, - frequency_penalty=frequency_penalty, - repeat_penalty=repeat_penalty, - tfs_z=tfs_z, - mirostat_mode=mirostat_mode, - mirostat_tau=mirostat_tau, - mirostat_eta=mirostat_eta, - model=model, - logits_processor=logits_processor, - grammar=llama_grammar.LlamaGrammar.from_string( - follow_up_gbnf_tool_grammar, verbose=llama.verbose - ), + except Exception as e: + warnings.warn( + f"Failed to parse function body as JSON schema, falling back to default grammar\n\n{e}", + category=RuntimeWarning, + stacklevel=2, ) - response = cast(llama_types.CreateCompletionResponse, response) - - tool_name = response["choices"][0]["text"][len("functions.") :] - tool = next( - (tool for tool in tools if tool["function"]["name"] == tool_name), None + grammar = llama_grammar.LlamaGrammar.from_string( + llama_grammar.JSON_GBNF, verbose=llama.verbose ) - - # Merge completions - function_call_dict: Union[ - Dict[str, str], - Dict[ - Literal["function_call"], - llama_types.ChatCompletionRequestAssistantMessageFunctionCall, - ], - ] = ( - { - "function_call": { - "name": tool_name, - "arguments": completions[0]["choices"][0]["text"], - } - } - if len(completions) == 1 - else {} + completion_or_chunks = llama.create_completion( + prompt=prompt, + **{ # type: ignore[arg-type] + **completion_kwargs, + "max_tokens": None, + "grammar": grammar, + }, ) - return { - "id": "chat" + completion["id"], - "object": "chat.completion", - "created": completion["created"], - "model": completion["model"], - "choices": [ - { - "finish_reason": "tool_calls", - "index": 0, - 
"logprobs": _convert_text_completion_logprobs_to_chat(completion["choices"][0]["logprobs"]), - "message": { - "role": "assistant", - "content": None, - "tool_calls": [ - { - "id": "call_" - + f"_{i}_" - + tool_name - + "_" - + completion["id"], - "type": "function", - "function": { - "name": tool_name, - "arguments": completion["choices"][0]["text"], - }, - } - for i, (tool_name, completion) in enumerate( - zip(completions_tool_name, completions) - ) - ], - **function_call_dict, - }, - } - ], - "usage": { - "completion_tokens": sum( - ( - completion["usage"]["completion_tokens"] - if "usage" in completion - else 0 - ) - for completion in completions - ), - "prompt_tokens": sum( - completion["usage"]["prompt_tokens"] if "usage" in completion else 0 - for completion in completions - ), - "total_tokens": sum( - completion["usage"]["total_tokens"] if "usage" in completion else 0 - for completion in completions + completion = cast(llama_types.CreateCompletionResponse, completion_or_chunks) + completions.append(completion) + completions_tool_name.append(tool_name) + prompt += completion["choices"][0]["text"] + prompt += "\n" + # Determine whether to call another tool or stop + response = cast( + llama_types.CreateCompletionResponse, + llama.create_completion( + prompt=prompt, + **{ # type: ignore[arg-type] + **completion_kwargs, + "temperature": 0, + "stream": False, + "stop": [*completion_kwargs["stop"], ":", ""], # type: ignore[misc] + "max_tokens": None, + "grammar": llama_grammar.LlamaGrammar.from_string( + follow_up_gbnf_tool_grammar, verbose=llama.verbose + ), + }, + ), + ) + tool_name = response["choices"][0]["text"][len("functions.") :] + tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None) + # Merge the completions into a single chat completion + chat_completion: llama_types.CreateChatCompletionResponse = { + "id": "chat" + completion["id"], + "object": "chat.completion", + "created": completion["created"], + "model": completion["model"], + "choices": [ + { + "finish_reason": "tool_calls", + "index": 0, + "logprobs": _convert_text_completion_logprobs_to_chat( + completion["choices"][0]["logprobs"] ), - }, + "message": { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_" + f"_{i}_" + tool_name + "_" + completion["id"], + "type": "function", + "function": { + "name": tool_name, + "arguments": completion["choices"][0]["text"], + }, + } + for i, (tool_name, completion) in enumerate( + zip(completions_tool_name, completions) + ) + ], + }, + } + ], + "usage": { + "completion_tokens": sum( + (completion["usage"]["completion_tokens"] if "usage" in completion else 0) + for completion in completions + ), + "prompt_tokens": sum( + completion["usage"]["prompt_tokens"] if "usage" in completion else 0 + for completion in completions + ), + "total_tokens": sum( + completion["usage"]["total_tokens"] if "usage" in completion else 0 + for completion in completions + ), + }, + } + if len(completions) == 1: + single_function_call: llama_types.ChatCompletionResponseFunctionCall = { + "name": tool_name, + "arguments": completions[0]["choices"][0]["text"], } - - raise ValueError("Automatic streaming tool choice is not supported") + chat_completion["choices"][0]["message"]["function_call"] = single_function_call + return chat_completion diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index f3985ad..5dc68ee 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -10,7 +10,9 @@ NewType, Optional, TYPE_CHECKING, + List, ) +from 
enum import IntEnum from llama_cpp._ctypes_extensions import ( load_shared_library, @@ -40,70 +42,48 @@ ctypes_function = ctypes_function_for_shared_library(_lib) -# from ggml.h -# // NOTE: always add types at the end of the enum to keep backward compatibility -# enum ggml_type { -# GGML_TYPE_F32 = 0, -# GGML_TYPE_F16 = 1, -# GGML_TYPE_Q4_0 = 2, -# GGML_TYPE_Q4_1 = 3, -# // GGML_TYPE_Q4_2 = 4, support has been removed -# // GGML_TYPE_Q4_3 = 5, support has been removed -# GGML_TYPE_Q5_0 = 6, -# GGML_TYPE_Q5_1 = 7, -# GGML_TYPE_Q8_0 = 8, -# GGML_TYPE_Q8_1 = 9, -# GGML_TYPE_Q2_K = 10, -# GGML_TYPE_Q3_K = 11, -# GGML_TYPE_Q4_K = 12, -# GGML_TYPE_Q5_K = 13, -# GGML_TYPE_Q6_K = 14, -# GGML_TYPE_Q8_K = 15, -# GGML_TYPE_IQ2_XXS = 16, -# GGML_TYPE_IQ2_XS = 17, -# GGML_TYPE_IQ3_XXS = 18, -# GGML_TYPE_IQ1_S = 19, -# GGML_TYPE_IQ4_NL = 20, -# GGML_TYPE_IQ3_S = 21, -# GGML_TYPE_IQ2_S = 22, -# GGML_TYPE_IQ4_XS = 23, -# GGML_TYPE_I8 = 24, -# GGML_TYPE_I16 = 25, -# GGML_TYPE_I32 = 26, -# GGML_TYPE_I64 = 27, -# GGML_TYPE_F64 = 28, -# GGML_TYPE_IQ1_M = 29, -# GGML_TYPE_COUNT, -# }; -GGML_TYPE_F32 = 0 -GGML_TYPE_F16 = 1 -GGML_TYPE_Q4_0 = 2 -GGML_TYPE_Q4_1 = 3 -GGML_TYPE_Q5_0 = 6 -GGML_TYPE_Q5_1 = 7 -GGML_TYPE_Q8_0 = 8 -GGML_TYPE_Q8_1 = 9 -GGML_TYPE_Q2_K = 10 -GGML_TYPE_Q3_K = 11 -GGML_TYPE_Q4_K = 12 -GGML_TYPE_Q5_K = 13 -GGML_TYPE_Q6_K = 14 -GGML_TYPE_Q8_K = 15 -GGML_TYPE_IQ2_XXS = 16 -GGML_TYPE_IQ2_XS = 17 -GGML_TYPE_IQ3_XXS = 18 -GGML_TYPE_IQ1_S = 19 -GGML_TYPE_IQ4_NL = 20 -GGML_TYPE_IQ3_S = 21 -GGML_TYPE_IQ2_S = 22 -GGML_TYPE_IQ4_XS = 23 -GGML_TYPE_I8 = 24 -GGML_TYPE_I16 = 25 -GGML_TYPE_I32 = 26 -GGML_TYPE_I64 = 27 -GGML_TYPE_F64 = 28 -GGML_TYPE_IQ1_M = 29 -GGML_TYPE_COUNT = 30 +# From ggml.h: enum ggml_type +class GGMLType(IntEnum): + GGML_TYPE_F32 = 0 + GGML_TYPE_F16 = 1 + GGML_TYPE_Q4_0 = 2 + GGML_TYPE_Q4_1 = 3 + # GGML_TYPE_Q4_2 = 4 # support has been removed + # GGML_TYPE_Q4_3 = 5 # support has been removed + GGML_TYPE_Q5_0 = 6 + GGML_TYPE_Q5_1 = 7 + GGML_TYPE_Q8_0 = 8 + GGML_TYPE_Q8_1 = 9 + GGML_TYPE_Q2_K = 10 + GGML_TYPE_Q3_K = 11 + GGML_TYPE_Q4_K = 12 + GGML_TYPE_Q5_K = 13 + GGML_TYPE_Q6_K = 14 + GGML_TYPE_Q8_K = 15 + GGML_TYPE_IQ2_XXS = 16 + GGML_TYPE_IQ2_XS = 17 + GGML_TYPE_IQ3_XXS = 18 + GGML_TYPE_IQ1_S = 19 + GGML_TYPE_IQ4_NL = 20 + GGML_TYPE_IQ3_S = 21 + GGML_TYPE_IQ2_S = 22 + GGML_TYPE_IQ4_XS = 23 + GGML_TYPE_I8 = 24 + GGML_TYPE_I16 = 25 + GGML_TYPE_I32 = 26 + GGML_TYPE_I64 = 27 + GGML_TYPE_F64 = 28 + GGML_TYPE_IQ1_M = 29 + GGML_TYPE_BF16 = 30 + # GGML_TYPE_Q4_0_4_4 = 31 # support has been removed from gguf files + # GGML_TYPE_Q4_0_4_8 = 32 + # GGML_TYPE_Q4_0_8_8 = 33 + GGML_TYPE_TQ1_0 = 34 + GGML_TYPE_TQ2_0 = 35 + # GGML_TYPE_IQ4_NL_4_4 = 36 + # GGML_TYPE_IQ4_NL_4_8 = 37 + # GGML_TYPE_IQ4_NL_8_8 = 38 + GGML_TYPE_COUNT = 39 # from ggml-backend.h # typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data); @@ -165,6 +145,10 @@ # llama_sampler_p = NewType("llama_sampler_p", int) # llama_sampler_p_ctypes = ctypes.c_void_p +# struct llama_kv_cache; +llama_kv_cache_p = NewType("llama_kv_cache_p", int) +llama_kv_cache_p_ctypes = ctypes.c_void_p + # typedef int32_t llama_pos; llama_pos = ctypes.c_int32 # typedef int32_t llama_token; @@ -228,6 +212,11 @@ # LLAMA_VOCAB_PRE_TYPE_MINERVA = 27, # LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, # LLAMA_VOCAB_PRE_TYPE_GPT4O = 29, +# LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30, +# LLAMA_VOCAB_PRE_TYPE_TRILLION = 31, +# LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32, +# LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, +# LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34, 
# }; LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0 LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1 @@ -245,7 +234,7 @@ LLAMA_VOCAB_PRE_TYPE_DBRX = 13 LLAMA_VOCAB_PRE_TYPE_SMAUG = 14 LLAMA_VOCAB_PRE_TYPE_PORO = 15 -LLAMA_VOCAV_PRE_TYPE_CHATGLM3 = 16 +LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16 LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17 LLAMA_VOCAB_PRE_TYPE_VIKING = 18 LLAMA_VOCAB_PRE_TYPE_JAIS = 19 @@ -259,6 +248,11 @@ LLAMA_VOCAB_PRE_TYPE_MINERVA = 27 LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28 LLAMA_VOCAB_PRE_TYPE_GPT4O = 29 +LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30 +LLAMA_VOCAB_PRE_TYPE_TRILLION = 31 +LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32 +LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33 +LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34 # // note: these values should be synchronized with ggml_rope @@ -630,10 +624,25 @@ class llama_model_kv_override(ctypes.Structure): value: Union[int, float, bool, bytes] +# struct llama_model_tensor_buft_override { +# const char * pattern; +# ggml_backend_buffer_type_t buft; +# }; + +class llama_model_tensor_buft_override(ctypes.Structure): + _fields_ = [ + ("pattern", ctypes.c_char_p), + ("buft", ctypes.c_int), + ] + + # struct llama_model_params { # // NULL-terminated list of devices to use for offloading (if NULL, all available devices are used) # ggml_backend_dev_t * devices; +# // NULL-terminated list of buffer types to use for tensors that match a pattern +# const struct llama_model_tensor_buft_override * tensor_buft_overrides; + # int32_t n_gpu_layers; // number of layers to store in VRAM # enum llama_split_mode split_mode; // how to split the model across multiple GPUs @@ -668,6 +677,8 @@ class llama_model_params(ctypes.Structure): """Parameters for llama_model Attributes: + devices (ctypes.Array[ggml_backend_dev_t]): NULL-terminated list of devices to use for offloading (if NULL, all available devices are used) + tensor_buft_overrides (ctypes.Array[llama_model_tensor_buft_override]): NULL-terminated list of buffer types to use for tensors that match a pattern n_gpu_layers (int): number of layers to store in VRAM split_mode (int): how to split the model across multiple GPUs main_gpu (int): the GPU that is used for the entire model. main_gpu interpretation depends on split_mode: LLAMA_SPLIT_NONE: the GPU that is used for the entire model LLAMA_SPLIT_ROW: the GPU that is used for small tensors and intermediate results LLAMA_SPLIT_LAYER: ignored @@ -681,6 +692,8 @@ class llama_model_params(ctypes.Structure): check_tensors (bool): validate model tensor data""" if TYPE_CHECKING: + devices: CtypesArray[ctypes.c_void_p] # NOTE: unused + tensor_buft_overrides: CtypesArray[llama_model_tensor_buft_override] # NOTE: unused n_gpu_layers: int split_mode: int main_gpu: int @@ -695,6 +708,7 @@ class llama_model_params(ctypes.Structure): _fields_ = [ ("devices", ctypes.c_void_p), # NOTE: unnused + ("tensor_buft_overrides", ctypes.c_void_p), # NOTE: unused ("n_gpu_layers", ctypes.c_int32), ("split_mode", ctypes.c_int), ("main_gpu", ctypes.c_int32), @@ -754,7 +768,8 @@ class llama_model_params(ctypes.Structure): # void * abort_callback_data; # }; class llama_context_params(ctypes.Structure): - """Parameters for llama_context + """Parameters for llama_context. NOTE: changing the default values of parameters marked as [EXPERIMENTAL] + may cause crashes or incorrect results in certain configurations. 
Attributes: n_ctx (int): text context, 0 = from model @@ -764,7 +779,7 @@ class llama_context_params(ctypes.Structure): n_threads (int): number of threads to use for generation n_threads_batch (int): number of threads to use for batch processing rope_scaling_type (int): RoPE scaling type, from `enum llama_rope_scaling_type` - pooling_type (int): whether to pool (sum) embedding results by sequence id (ignored if no pooling layer) + pooling_type (int): whether to pool (sum) embedding results by sequence id attention_type (int): attention type to use for embeddings rope_freq_base (float): RoPE base frequency, 0 = from model rope_freq_scale (float): RoPE frequency scaling factor, 0 = from model @@ -773,18 +788,19 @@ class llama_context_params(ctypes.Structure): yarn_beta_fast (float): YaRN low correction dim yarn_beta_slow (float): YaRN high correction dim yarn_orig_ctx (int): YaRN original context size - defrag_thold (float): defragment the KV cache if holes/size > thold, < 0 disabled (default) + defrag_thold (float): defragment the KV cache if holes/size > thold, <= 0 disabled (default) cb_eval (ggml_backend_sched_eval_callback): callback for scheduling eval - cb_eval_user_data (ctypes.ctypes.c_void_p): user data for cb_eval - type_k (int): data type for K cache - type_v (int): data type for V cache - logits_all (bool): the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead) - embeddings (bool): if true, extract embeddings (together with logits) - offload_kqv (bool): whether to offload the KQV ops (including the KV cache) to GPU - flash_attn (bool): whether to use flash attention - no_perf (bool): whether to measure performance timings + cb_eval_user_data (ctypes.c_void_p): user data for cb_eval + type_k (GGMLType): data type for K cache [EXPERIMENTAL] + type_v (GGMLType): data type for V cache [EXPERIMENTAL] abort_callback (ggml_abort_callback): abort callback if it returns true, execution of llama_decode() will be aborted - abort_callback_data (ctypes.ctypes.c_void_p): data for abort_callback + abort_callback_data (ctypes.c_void_p): data for abort_callback + embeddings (bool): if true, extract embeddings (together with logits) + offload_kqv (bool): offload the KQV ops (including the KV cache) to GPU + flash_attn (bool): use flash attention [EXPERIMENTAL] + no_perf (bool): measure performance timings + op_offload (bool): offload host tensor operations to device + swa_full (bool): use full-size SWA cache """ if TYPE_CHECKING: @@ -807,15 +823,16 @@ class llama_context_params(ctypes.Structure): defrag_thold: float cb_eval: Callable[[ctypes.c_void_p, bool], bool] cb_eval_user_data: ctypes.c_void_p - type_k: int - type_v: int - logits_all: bool + type_k: GGMLType + type_v: GGMLType + abort_callback: Callable[[ctypes.c_void_p], bool] + abort_callback_data: ctypes.c_void_p embeddings: bool offload_kqv: bool flash_attn: bool no_perf: bool - abort_callback: Callable[[ctypes.c_void_p], bool] - abort_callback_data: ctypes.c_void_p + op_offload: bool + swa_full: bool _fields_ = [ ("n_ctx", ctypes.c_uint32), @@ -839,13 +856,15 @@ class llama_context_params(ctypes.Structure): ("cb_eval_user_data", ctypes.c_void_p), ("type_k", ctypes.c_int), ("type_v", ctypes.c_int), - ("logits_all", ctypes.c_bool), + ("abort_callback", ggml_abort_callback), + ("abort_callback_data", ctypes.c_void_p), + # Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value ("embeddings", ctypes.c_bool), ("offload_kqv", ctypes.c_bool), 
("flash_attn", ctypes.c_bool), ("no_perf", ctypes.c_bool), - ("abort_callback", ggml_abort_callback), - ("abort_callback_data", ctypes.c_void_p), + ("op_offload", ctypes.c_bool), + ("swa_full", ctypes.c_bool), ] @@ -867,17 +886,18 @@ class llama_context_params(ctypes.Structure): # // model quantization parameters # typedef struct llama_model_quantize_params { -# int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency() -# enum llama_ftype ftype; // quantize to this llama_ftype -# enum ggml_type output_tensor_type; // output tensor type -# enum ggml_type token_embedding_type; // token embeddings tensor type -# bool allow_requantize; // allow quantizing non-f32/f16 tensors -# bool quantize_output_tensor; // quantize output.weight -# bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored -# bool pure; // quantize all tensors to the default type -# bool keep_split; // quantize to the same number of shards -# void * imatrix; // pointer to importance matrix data -# void * kv_overrides; // pointer to vector containing overrides +# int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency() +# enum llama_ftype ftype; // quantize to this llama_ftype +# enum ggml_type output_tensor_type; // output tensor type +# enum ggml_type token_embedding_type; // token embeddings tensor type +# bool allow_requantize; // allow quantizing non-f32/f16 tensors +# bool quantize_output_tensor; // quantize output.weight +# bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored +# bool pure; // quantize all tensors to the default type +# bool keep_split; // quantize to the same number of shards +# void * imatrix; // pointer to importance matrix data +# void * kv_overrides; // pointer to vector containing overrides +# void * tensor_types; // pointer to vector containing tensor types # } llama_model_quantize_params; class llama_model_quantize_params(ctypes.Structure): """Parameters for llama_model_quantize @@ -894,6 +914,8 @@ class llama_model_quantize_params(ctypes.Structure): keep_split (bool): quantize to the same number of shards imatrix (ctypes.c_void_p): pointer to importance matrix data kv_overrides (ctypes.c_void_p): pointer to vector containing overrides + tensor_types (ctypes.c_void_p): pointer to vector containing tensor types + prune_layers (ctypes.c_void_p): pointer to vector containing layer indices to prune """ if TYPE_CHECKING: @@ -908,6 +930,8 @@ class llama_model_quantize_params(ctypes.Structure): keep_split: bool imatrix: ctypes.c_void_p kv_overrides: ctypes.c_void_p + tensor_types: ctypes.c_void_p + prune_layers: ctypes.c_void_p _fields_ = [ ("nthread", ctypes.c_int32), @@ -921,6 +945,8 @@ class llama_model_quantize_params(ctypes.Structure): ("keep_split", ctypes.c_bool), ("imatrix", ctypes.c_void_p), ("kv_overrides", ctypes.c_void_p), + ("tensor_types", ctypes.c_void_p), + ("prune_layers", ctypes.c_void_p), ] @@ -1317,7 +1343,18 @@ def llama_get_model(ctx: llama_context_p, /) -> Optional[llama_model_p]: ... 
-# LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); +# LLAMA_API struct llama_kv_cache * llama_get_kv_self ( struct llama_context * ctx); +@ctypes_function( + "llama_get_kv_self", + [llama_context_p_ctypes], + llama_kv_cache_p_ctypes, +) +def llama_get_kv_self(ctx: llama_context_p, /) -> Optional[llama_kv_cache_p]: + """Get the KV cache for self-attention""" + ... + + +# LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); @ctypes_function("llama_pooling_type", [llama_context_p_ctypes], ctypes.c_int) def llama_pooling_type(ctx: llama_context_p, /) -> int: ... @@ -1352,6 +1389,11 @@ def llama_model_n_embd(model: llama_model_p, /) -> int: def llama_model_n_layer(model: llama_model_p, /) -> int: ... +# LLAMA_API int32_t llama_model_dev_layer (const struct llama_model * model, int32_t il); +@ctypes_function("llama_model_dev_layer", [llama_model_p_ctypes, ctypes.c_int32], ctypes.c_int32) +def llama_model_dev_layer(model: llama_model_p, il: Union[ctypes.c_int32, int], /) -> int: + ... + # LLAMA_API int32_t llama_model_n_head (const struct llama_model * model); @ctypes_function("llama_model_n_head", [llama_model_p_ctypes], ctypes.c_int32) @@ -1694,164 +1736,55 @@ def llama_apply_adapter_cvec( ... -# // -# // KV cache -# // - - -# // Information associated with an individual cell in the KV cache view. -# struct llama_kv_cache_view_cell { -# // The position for this cell. Takes KV cache shifts into account. -# // May be negative if the cell is not populated. -# llama_pos pos; -# }; -class llama_kv_cache_view_cell(ctypes.Structure): - """Information associated with an individual cell in the KV cache view. - - Attributes: - pos (llama_pos): The position for this cell. Takes KV cache shifts into account. - May be negative if the cell is not populated.""" - - if TYPE_CHECKING: - pos: llama_pos - - _fields_ = [("pos", llama_pos)] - - -# // An updateable view of the KV cache. -# struct llama_kv_cache_view { -# // Number of KV cache cells. This will be the same as the context size. -# int32_t n_cells; - -# // Maximum number of sequences that can exist in a cell. It's not an error -# // if there are more sequences in a cell than this value, however they will -# // not be visible in the view cells_sequences. -# int32_t n_seq_max; - -# // Number of tokens in the cache. For example, if there are two populated -# // cells, the first with 1 sequence id in it and the second with 2 sequence -# // ids then you'll have 3 tokens. -# int32_t token_count; - -# // Number of populated cache cells. -# int32_t used_cells; - -# // Maximum contiguous empty slots in the cache. -# int32_t max_contiguous; - -# // Index to the start of the max_contiguous slot range. Can be negative -# // when cache is full. -# int32_t max_contiguous_idx; - -# // Information for an individual cell. -# struct llama_kv_cache_view_cell * cells; - - -# // The sequences for each cell. There will be n_seq_max items per cell. 
-# llama_seq_id * cells_sequences; -# }; -class llama_kv_cache_view(ctypes.Structure): - if TYPE_CHECKING: - n_cells: int - n_max_seq: int - token_count: int - used_cells: int - max_contiguous: int - max_contiguous_idx: int - cells: CtypesArray[llama_kv_cache_view_cell] - cells_sequences: CtypesArray[llama_seq_id] - - _fields_ = [ - ("n_cells", ctypes.c_int32), - ("n_max_seq", ctypes.c_int32), - ("token_count", ctypes.c_int32), - ("used_cells", ctypes.c_int32), - ("max_contiguous", ctypes.c_int32), - ("max_contiguous_idx", ctypes.c_int32), - ("cells", ctypes.POINTER(llama_kv_cache_view_cell)), - ("cells_sequences", ctypes.POINTER(llama_seq_id)), - ] - - -llama_kv_cache_view_p = ctypes.POINTER(llama_kv_cache_view) - - -# // Create an empty KV cache view. (use only for debugging purposes) -# LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max); -@ctypes_function( - "llama_kv_cache_view_init", - [llama_context_p_ctypes, ctypes.c_int32], - llama_kv_cache_view, -) -def llama_kv_cache_view_init( - ctx: llama_context_p, n_seq_max: Union[ctypes.c_int32, int], / -) -> llama_kv_cache_view: - """Create an empty KV cache view. (use only for debugging purposes)""" - ... - - -# // Free a KV cache view. (use only for debugging purposes) -# LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view); -@ctypes_function("llama_kv_cache_view_free", [llama_kv_cache_view_p], None) -def llama_kv_cache_view_free(view: "ctypes.pointer[llama_kv_cache_view]", /): # type: ignore - """Free a KV cache view. (use only for debugging purposes)""" - ... - - -# // Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes) -# LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view); -@ctypes_function( - "llama_kv_cache_view_update", [llama_context_p_ctypes, llama_kv_cache_view_p], None -) -def llama_kv_cache_view_update(ctx: llama_context_p, view: CtypesPointerOrRef[llama_kv_cache_view], /): # type: ignore - """Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)""" - ... - - # // Returns the number of tokens in the KV cache (slow, use only for debug) # // If a KV cell has multiple sequences assigned to it, it will be counted multiple times -# LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx); +# LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx); @ctypes_function( - "llama_get_kv_cache_token_count", [llama_context_p_ctypes], ctypes.c_int32 + "llama_kv_self_n_tokens", [llama_context_p_ctypes], ctypes.c_int32 ) -def llama_get_kv_cache_token_count(ctx: llama_context_p, /) -> int: +def llama_kv_self_n_tokens(ctx: llama_context_p, /) -> int: """Returns the number of tokens in the KV cache (slow, use only for debug) If a KV cell has multiple sequences assigned to it, it will be counted multiple times """ ... + # // Returns the number of used KV cells (i.e. 
have at least one sequence assigned to them) -# LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx); +# LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx); @ctypes_function( - "llama_get_kv_cache_used_cells", [llama_context_p_ctypes], ctypes.c_int32 + "llama_kv_self_used_cells", [llama_context_p_ctypes], ctypes.c_int32 ) -def llama_get_kv_cache_used_cells(ctx: llama_context_p, /) -> int: +def llama_kv_self_used_cells(ctx: llama_context_p, /) -> int: """Returns the number of used KV cells (i.e. have at least one sequence assigned to them)""" ... # // Clear the KV cache - both cell info is erased and KV data is zeroed -# LLAMA_API void llama_kv_cache_clear( +# LLAMA_API void llama_kv_self_clear( # struct llama_context * ctx); -@ctypes_function("llama_kv_cache_clear", [llama_context_p_ctypes], None) -def llama_kv_cache_clear(ctx: llama_context_p, /): - """Clear the KV cache""" +@ctypes_function( + "llama_kv_self_clear", [llama_context_p_ctypes], None +) +def llama_kv_self_clear(ctx: llama_context_p, /): + """Clear the KV cache - both cell info is erased and KV data is zeroed""" ... + + # // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) # // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails # // seq_id < 0 : match any sequence # // p0 < 0 : [0, p1] # // p1 < 0 : [p0, inf) -# LLAMA_API bool llama_kv_cache_seq_rm( +# LLAMA_API bool llama_kv_self_seq_rm( # struct llama_context * ctx, # llama_seq_id seq_id, # llama_pos p0, # llama_pos p1); @ctypes_function( - "llama_kv_cache_seq_rm", + "llama_kv_self_seq_rm", [ llama_context_p_ctypes, llama_seq_id, @@ -1860,7 +1793,7 @@ def llama_kv_cache_clear(ctx: llama_context_p, /): ], ctypes.c_bool, ) -def llama_kv_cache_seq_rm( +def llama_kv_self_seq_rm( ctx: llama_context_p, seq_id: Union[llama_seq_id, int], p0: Union[llama_pos, int], @@ -1881,14 +1814,14 @@ def llama_kv_cache_seq_rm( # // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence # // p0 < 0 : [0, p1] # // p1 < 0 : [p0, inf) -# LLAMA_API void llama_kv_cache_seq_cp( +# LLAMA_API void llama_kv_self_seq_cp( # struct llama_context * ctx, # llama_seq_id seq_id_src, # llama_seq_id seq_id_dst, # llama_pos p0, # llama_pos p1); @ctypes_function( - "llama_kv_cache_seq_cp", + "llama_kv_self_seq_cp", [ llama_context_p_ctypes, llama_seq_id, @@ -1898,7 +1831,7 @@ def llama_kv_cache_seq_rm( ], None, ) -def llama_kv_cache_seq_cp( +def llama_kv_self_seq_cp( ctx: llama_context_p, seq_id_src: Union[llama_seq_id, int], seq_id_dst: Union[llama_seq_id, int], @@ -1914,31 +1847,32 @@ def llama_kv_cache_seq_cp( # // Removes all tokens that do not belong to the specified sequence -# LLAMA_API void llama_kv_cache_seq_keep( +# LLAMA_API void llama_kv_self_seq_keep( # struct llama_context * ctx, # llama_seq_id seq_id); @ctypes_function( - "llama_kv_cache_seq_keep", [llama_context_p_ctypes, llama_seq_id], None + "llama_kv_self_seq_keep", [llama_context_p_ctypes, llama_seq_id], None ) -def llama_kv_cache_seq_keep(ctx: llama_context_p, seq_id: Union[llama_seq_id, int], /): +def llama_kv_self_seq_keep(ctx: llama_context_p, seq_id: Union[llama_seq_id, int], /): """Removes all tokens that do not belong to the specified sequence""" ... 
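A minimal migration sketch for code that called the old llama_kv_cache_* names (not part of this patch): it uses only the renamed wrappers defined above, `ctx` is assumed to be a valid llama_context_p obtained elsewhere, and the helper name is hypothetical.

import llama_cpp.llama_cpp as llama_cpp

def prune_kv_to_sequence(ctx, seq_id: int, from_pos: int) -> int:
    # Keep only the cells that belong to seq_id ...
    llama_cpp.llama_kv_self_seq_keep(ctx, seq_id)
    # ... then drop that sequence's tail; p1 < 0 means [from_pos, inf).
    llama_cpp.llama_kv_self_seq_rm(ctx, seq_id, from_pos, -1)
    # Slow, debug-only count (see llama_kv_self_n_tokens above) to verify the prune.
    return llama_cpp.llama_kv_self_n_tokens(ctx)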
+ # // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) # // If the KV cache is RoPEd, the KV data is updated accordingly: # // - lazily on next llama_decode() -# // - explicitly with llama_kv_cache_update() +# // - explicitly with llama_kv_self_update() # // p0 < 0 : [0, p1] # // p1 < 0 : [p0, inf) -# LLAMA_API void llama_kv_cache_seq_add( +# LLAMA_API void llama_kv_self_seq_add( # struct llama_context * ctx, # llama_seq_id seq_id, # llama_pos p0, # llama_pos p1, # llama_pos delta); @ctypes_function( - "llama_kv_cache_seq_add", + "llama_kv_self_seq_add", [ llama_context_p_ctypes, llama_seq_id, @@ -1948,7 +1882,7 @@ def llama_kv_cache_seq_keep(ctx: llama_context_p, seq_id: Union[llama_seq_id, in ], None, ) -def llama_kv_cache_seq_add( +def llama_kv_self_seq_add( ctx: llama_context_p, seq_id: Union[llama_seq_id, int], p0: Union[llama_pos, int], @@ -1959,7 +1893,7 @@ def llama_kv_cache_seq_add( """Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) If the KV cache is RoPEd, the KV data is updated accordingly: - lazily on next llama_decode() - - explicitly with llama_kv_cache_update() + - explicitly with llama_kv_self_update() p0 < 0 : [0, p1] p1 < 0 : [p0, inf)""" ... @@ -1969,14 +1903,14 @@ def llama_kv_cache_seq_add( # // If the KV cache is RoPEd, the KV data is updated accordingly # // p0 < 0 : [0, p1] # // p1 < 0 : [p0, inf) -# LLAMA_API void llama_kv_cache_seq_div( +# LLAMA_API void llama_kv_self_seq_div( # struct llama_context * ctx, # llama_seq_id seq_id, # llama_pos p0, # llama_pos p1, # int d); @ctypes_function( - "llama_kv_cache_seq_div", + "llama_kv_self_seq_div", [ llama_context_p_ctypes, llama_seq_id, @@ -1986,7 +1920,7 @@ def llama_kv_cache_seq_add( ], None, ) -def llama_kv_cache_seq_div( +def llama_kv_self_seq_div( ctx: llama_context_p, seq_id: Union[llama_seq_id, int], p0: Union[llama_pos, int], @@ -2001,32 +1935,45 @@ def llama_kv_cache_seq_div( ... +# // Returns the largest position present in the KV cache for the specified sequence +# LLAMA_API llama_pos llama_kv_self_seq_pos_max( +# struct llama_context * ctx, +# llama_seq_id seq_id); +@ctypes_function( + "llama_kv_self_seq_pos_max", [llama_context_p_ctypes, llama_seq_id], llama_pos +) +def llama_kv_self_seq_pos_max( + ctx: llama_context_p, seq_id: Union[llama_seq_id, int], / +) -> int: + """Returns the largest position present in the KV cache for the specified sequence""" + ... + + # // Defragment the KV cache # // This will be applied: # // - lazily on next llama_decode() -# // - explicitly with llama_kv_cache_update() -# LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx); -@ctypes_function("llama_kv_cache_defrag", [llama_context_p_ctypes], None) -def llama_kv_cache_defrag(ctx: llama_context_p, /): +# // - explicitly with llama_kv_self_update() +# LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx); +@ctypes_function("llama_kv_self_defrag", [llama_context_p_ctypes], None) +def llama_kv_self_defrag(ctx: llama_context_p, /): """Defragment the KV cache This will be applied: - lazily on next llama_decode() - - explicitly with llama_kv_cache_update()""" + - explicitly with llama_kv_self_update()""" ... # // Apply the KV cache updates (such as K-shifts, defragmentation, etc.) 
-# LLAMA_API void llama_kv_cache_update(struct llama_context * ctx); -@ctypes_function("llama_kv_cache_update", [llama_context_p_ctypes], None) -def llama_kv_cache_update(ctx: llama_context_p, /): +# LLAMA_API void llama_kv_self_update(struct llama_context * ctx); +@ctypes_function("llama_kv_self_update", [llama_context_p_ctypes], None) +def llama_kv_self_update(ctx: llama_context_p, /): """Apply the KV cache updates (such as K-shifts, defragmentation, etc.)""" ... - # // Check if the context supports KV cache shifting -# LLAMA_API bool llama_kv_cache_can_shift(struct llama_context * ctx); -@ctypes_function("llama_kv_cache_can_shift", [llama_context_p_ctypes], ctypes.c_bool) -def llama_kv_cache_can_shift(ctx: llama_context_p, /) -> bool: +# LLAMA_API bool llama_kv_self_can_shift(struct llama_context * ctx); +@ctypes_function("llama_kv_self_can_shift", [llama_context_p_ctypes], ctypes.c_bool) +def llama_kv_self_can_shift(ctx: llama_context_p, /) -> bool: """Check if the context supports KV cache shifting""" ... @@ -2547,6 +2494,16 @@ def llama_set_causal_attn(ctx: llama_context_p, causal_attn: bool, /): ... +# // Set whether the model is in warmup mode or not +# // If true, all model tensors are activated during llama_decode() to load and cache their weights. +# LLAMA_API void llama_set_warmup(struct llama_context * ctx, bool warmup); +@ctypes_function("llama_set_warmup", [llama_context_p_ctypes, ctypes.c_bool], None) +def llama_set_warmup(ctx: llama_context_p, warmup: bool, /): + """Set whether the model is in warmup mode or not + If true, all model tensors are activated during llama_decode() to load and cache their weights.""" + ... + + # // Set abort callback # LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data); @ctypes_function( @@ -3244,7 +3201,7 @@ def llama_detokenize( # /// Apply chat template. Inspired by hf apply_chat_template() on python. # /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model" # /// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template -# /// @param tmpl A Jinja template to use for this chat. If this is nullptr, the model’s default chat template will be used instead. +# /// @param tmpl A Jinja template to use for this chat. If this is nullptr, the model's default chat template will be used instead. # /// @param chat Pointer to a list of multiple llama_chat_message # /// @param n_msg Number of llama_chat_message in this chat # /// @param add_ass Whether to end the prompt with the token(s) that indicate the start of an assistant message. @@ -3576,6 +3533,7 @@ def llama_sampler_init_softmax() -> llama_sampler_p: # /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 +# /// Setting k <= 0 makes this a noop # LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k); @ctypes_function("llama_sampler_init_top_k", [ctypes.c_int32], llama_sampler_p_ctypes) def llama_sampler_init_top_k(k: int) -> llama_sampler_p: @@ -3701,6 +3659,10 @@ def llama_sampler_init_mirostat_v2( ... +# /// @details Intializes a GBNF grammar, see grammars/README.md for details. +# /// @param vocab The vocabulary that this grammar will be used with. 
+# /// @param grammar_str The production rules for the grammar, encoded as a string. Returns an empty grammar if empty. Returns NULL if parsing of grammar_str fails. +# /// @param grammar_root The name of the start symbol for the grammar. # LLAMA_API struct llama_sampler * llama_sampler_init_grammar( # const struct llama_vocab * vocab, # const char * grammar_str, diff --git a/llama_cpp/llama_types.py b/llama_cpp/llama_types.py index f647822..67772b8 100644 --- a/llama_cpp/llama_types.py +++ b/llama_cpp/llama_types.py @@ -154,13 +154,13 @@ class ChatCompletionStreamResponseChoice(TypedDict): finish_reason: Optional[Literal["stop", "length", "tool_calls", "function_call"]] logprobs: NotRequired[Optional[ChatCompletionLogprobs]] - class CreateChatCompletionStreamResponse(TypedDict): id: str model: str object: Literal["chat.completion.chunk"] created: int choices: List[ChatCompletionStreamResponseChoice] + usage: NotRequired[CompletionUsage] class ChatCompletionFunctions(TypedDict): diff --git a/llama_cpp/llava_cpp.py b/llama_cpp/llava_cpp.py deleted file mode 100644 index d9dfaf5..0000000 --- a/llama_cpp/llava_cpp.py +++ /dev/null @@ -1,158 +0,0 @@ -from __future__ import annotations - -import os -from ctypes import ( - c_bool, - c_char_p, - c_int, - c_uint8, - c_float, - c_void_p, - POINTER, - _Pointer, # type: ignore - Structure, -) -import pathlib -from typing import ( - Union, - NewType, - Optional, - TYPE_CHECKING, -) - -import llama_cpp.llama_cpp as llama_cpp - -from llama_cpp._ctypes_extensions import ( - load_shared_library, - ctypes_function_for_shared_library, -) - -if TYPE_CHECKING: - from llama_cpp._ctypes_extensions import ( - CtypesArray, - ) - - -# Specify the base name of the shared library to load -_libllava_base_name = "llava" -_libllava_override_path = os.environ.get("LLAVA_CPP_LIB") -_libllava_base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__))) / "lib" if _libllava_override_path is None else pathlib.Path() - -# Load the library -_libllava = load_shared_library(_libllava_base_name, _libllava_base_path) - -ctypes_function = ctypes_function_for_shared_library(_libllava) - - -################################################ -# llava.h -################################################ - -# struct clip_ctx; -clip_ctx_p = NewType("clip_ctx_p", int) -clip_ctx_p_ctypes = c_void_p - - -# struct llava_image_embed { -# float * embed; -# int n_image_pos; -# }; -class llava_image_embed(Structure): - _fields_ = [ - ("embed", POINTER(c_float)), - ("n_image_pos", c_int), - ] - - -# /** sanity check for clip <-> llava embed size match */ -# LLAVA_API bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip); -@ctypes_function( - "llava_validate_embed_size", - [llama_cpp.llama_context_p_ctypes, clip_ctx_p_ctypes], - c_bool, -) -def llava_validate_embed_size( - ctx_llama: llama_cpp.llama_context_p, ctx_clip: clip_ctx_p, / -) -> bool: - ... - - -# /** build an image embed from image file bytes */ -# LLAVA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length); -@ctypes_function( - "llava_image_embed_make_with_bytes", - [clip_ctx_p_ctypes, c_int, POINTER(c_uint8), c_int], - POINTER(llava_image_embed), -) -def llava_image_embed_make_with_bytes( - ctx_clip: clip_ctx_p, - n_threads: Union[c_int, int], - image_bytes: CtypesArray[c_uint8], - image_bytes_length: Union[c_int, int], - /, -) -> "_Pointer[llava_image_embed]": - ... 
- - -# /** build an image embed from a path to an image filename */ -# LLAVA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path); -@ctypes_function( - "llava_image_embed_make_with_filename", - [clip_ctx_p_ctypes, c_int, c_char_p], - POINTER(llava_image_embed), -) -def llava_image_embed_make_with_filename( - ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_path: bytes, / -) -> "_Pointer[llava_image_embed]": - ... - - -# LLAVA_API void llava_image_embed_free(struct llava_image_embed * embed); -# /** free an embedding made with llava_image_embed_make_* */ -@ctypes_function("llava_image_embed_free", [POINTER(llava_image_embed)], None) -def llava_image_embed_free(embed: "_Pointer[llava_image_embed]", /): - ... - - -# /** write the image represented by embed into the llama context with batch size n_batch, starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */ -# LLAVA_API bool llava_eval_image_embed(struct llama_context * ctx_llama, const struct llava_image_embed * embed, int n_batch, int * n_past); -@ctypes_function( - "llava_eval_image_embed", - [ - llama_cpp.llama_context_p_ctypes, - POINTER(llava_image_embed), - c_int, - POINTER(c_int), - ], - c_bool, -) -def llava_eval_image_embed( - ctx_llama: llama_cpp.llama_context_p, - embed: "_Pointer[llava_image_embed]", - n_batch: Union[c_int, int], - n_past: "_Pointer[c_int]", - /, -) -> bool: - ... - - -################################################ -# clip.h -################################################ - - -# /** load mmproj model */ -# CLIP_API struct clip_ctx * clip_model_load (const char * fname, int verbosity); -@ctypes_function("clip_model_load", [c_char_p, c_int], clip_ctx_p_ctypes) -def clip_model_load( - fname: bytes, verbosity: Union[c_int, int], / -) -> Optional[clip_ctx_p]: - ... - - -# /** free mmproj model */ -# CLIP_API void clip_free(struct clip_ctx * ctx); -@ctypes_function("clip_free", [clip_ctx_p_ctypes], None) -def clip_free(ctx: clip_ctx_p, /): - ... 
- diff --git a/llama_cpp/mtmd_cpp.py b/llama_cpp/mtmd_cpp.py new file mode 100644 index 0000000..9eb2ef4 --- /dev/null +++ b/llama_cpp/mtmd_cpp.py @@ -0,0 +1,649 @@ +from __future__ import annotations + +import os +import ctypes +from ctypes import ( + c_bool, + c_char_p, + c_int, + c_int32, + c_uint8, + c_uint32, + c_float, + c_size_t, + c_void_p, + POINTER, + _Pointer, # type: ignore + Structure, +) +import pathlib +from typing import ( + List, + Union, + NewType, + Optional, + TYPE_CHECKING, +) + +import llama_cpp.llama_cpp as llama_cpp + +from llama_cpp._ctypes_extensions import ( + load_shared_library, + ctypes_function_for_shared_library, +) + +if TYPE_CHECKING: + from llama_cpp.llama_types import ( + llama_token, + llama_pos, + ) + from llama_cpp._ctypes_extensions import ( + CtypesArray, + CtypesPointer, + ) + +# Define input text structure +class mtmd_input_text(Structure): + _fields_ = [ + ("text", c_char_p), + ("add_special", c_bool), + ("parse_special", c_bool), + ] + +# Define context parameters structure +class mtmd_context_params(Structure): + _fields_ = [ + ("use_gpu", c_bool), + ("print_timings", c_bool), + ("n_threads", c_int), + ("verbosity", c_int), + ("image_marker", c_char_p), # const char* + ("media_marker", c_char_p), # const char* + ] + +# Define input chunk type enum +mtmd_input_chunk_type = c_int +( + MTMD_INPUT_CHUNK_TYPE_TEXT, + MTMD_INPUT_CHUNK_TYPE_IMAGE, + MTMD_INPUT_CHUNK_TYPE_AUDIO, +) = (0, 1, 2) + +# Define slice template enum +mtmd_slice_tmpl = c_int +( + MTMD_SLICE_TMPL_NONE, + MTMD_SLICE_TMPL_MINICPMV_2_5, + MTMD_SLICE_TMPL_MINICPMV_2_6, + MTMD_SLICE_TMPL_LLAMA4, +) = (0, 1, 2, 3) + +# Define whisper filters structure +class whisper_filters(Structure): + _fields_ = [ + ("n_mel", c_int), + ] + +# Define mtmd_context structure +class mtmd_context(Structure): + _fields_ = [ + ("ctx_v", c_void_p), # clip_ctx* + ("ctx_a", c_void_p), # clip_ctx* + ("text_model", c_void_p), # const llama_model* + ("image_embd_v", POINTER(c_float)), # std::vector + ("print_timings", c_bool), + ("n_threads", c_int), + ("media_marker", c_char_p), # std::string + ("n_embd_text", c_int), + ("img_beg", c_char_p), # std::string + ("img_end", c_char_p), # std::string + ("aud_beg", c_char_p), # std::string + ("aud_end", c_char_p), # std::string + ("slice_tmpl", c_int), # mtmd_slice_tmpl + ("tok_ov_img_start", llama_cpp.llama_token), + ("tok_ov_img_end", llama_cpp.llama_token), + ("tok_slices_start", llama_cpp.llama_token), + ("tok_slices_end", llama_cpp.llama_token), + ("tok_sli_img_start", llama_cpp.llama_token), + ("tok_sli_img_end", llama_cpp.llama_token), + ("tok_sli_img_mid", llama_cpp.llama_token), + ("tok_row_end", llama_cpp.llama_token), + ("tok_row_end_trail", c_bool), + ("ov_img_first", c_bool), + ("use_mrope", c_bool), + ("w_filters", whisper_filters), + ] + +# Define bitmap structure +class mtmd_bitmap(Structure): + _fields_ = [ + ("nx", c_uint32), + ("ny", c_uint32), + ("data", POINTER(c_uint8)), # Vector represented as pointer + ("id", c_char_p), + ("is_audio", c_bool), + ] + +# Define image tokens structure +class mtmd_image_tokens(Structure): + _fields_ = [ + ("nx", c_uint32), + ("ny", c_uint32), + ("use_mrope_pos", c_bool), + ("batch_f32", c_void_p), # clip_image_f32_batch + ("id", c_char_p), + ] + +# Define audio tokens structure +class mtmd_audio_tokens(Structure): + _fields_ = [ + ("n_tokens", c_uint32), + ("batch_f32", c_void_p), # clip_image_f32_batch + ("id", c_char_p), + ] + +# Define input chunk structure +class mtmd_input_chunk(Structure): + _fields_ = [ + 
("type", mtmd_input_chunk_type), + ("tokens_text", POINTER(llama_cpp.llama_token)), # Vector represented as pointer + ("tokens_image", c_void_p), # mtmd_image_tokens_ptr + ("tokens_audio", c_void_p), # mtmd_audio_tokens_ptr + ] + +# Define input chunks structure +class mtmd_input_chunks(Structure): + _fields_ = [ + ("entries", POINTER(mtmd_input_chunk)), # Vector represented as pointer + ] + +# Define context pointer type +mtmd_context_p = NewType("mtmd_context_p", int) +mtmd_context_p_ctypes = c_void_p + +# Define bitmap pointer type +mtmd_bitmap_p = NewType("mtmd_bitmap_p", int) +mtmd_bitmap_p_ctypes = c_void_p + +# Define input chunks pointer type +mtmd_input_chunks_p = NewType("mtmd_input_chunks_p", int) +mtmd_input_chunks_p_ctypes = c_void_p + +# Define input chunk pointer type +mtmd_input_chunk_p = NewType("mtmd_input_chunk_p", int) +mtmd_input_chunk_p_ctypes = c_void_p + +# Define image tokens pointer type +mtmd_image_tokens_p = NewType("mtmd_image_tokens_p", int) +mtmd_image_tokens_p_ctypes = c_void_p + +# Define audio tokens pointer type +mtmd_audio_tokens_p = NewType("mtmd_audio_tokens_p", int) +mtmd_audio_tokens_p_ctypes = c_void_p + +# Load the library +_libmtmd_base_name = "mtmd" +_libmtmd_override_path = os.environ.get("mtmd_CPP_LIB") +_libmtmd_base_path = ( + pathlib.Path(os.path.abspath(os.path.dirname(__file__))) / "lib" + if _libmtmd_override_path is None + else pathlib.Path() +) + +_libmtmd = load_shared_library(_libmtmd_base_name, _libmtmd_base_path) +ctypes_function = ctypes_function_for_shared_library(_libmtmd) + +# Add core functions +@ctypes_function( + "mtmd_context_params_default", + [], + mtmd_context_params, +) +def mtmd_context_params_default() -> mtmd_context_params: + ... + +@ctypes_function( + "mtmd_init_from_file", + [c_char_p, llama_cpp.llama_model_p_ctypes, mtmd_context_params], + mtmd_context_p_ctypes, +) +def mtmd_init_from_file( + mmproj_fname: bytes, + text_model: llama_cpp.llama_model_p, + ctx_params: mtmd_context_params, + /, +) -> Optional[mtmd_context_p]: + ... + +@ctypes_function( + "mtmd_free", + [mtmd_context_p_ctypes], + None, +) +def mtmd_free(ctx: mtmd_context_p, /): + ... + +@ctypes_function( + "mtmd_default_marker", + [], + c_char_p, +) +def mtmd_default_marker() -> bytes: + ... + +################################################ +# mtmd.h +################################################ + +@ctypes_function( + "mtmd_tokenize", + [ + mtmd_context_p_ctypes, + mtmd_input_chunks_p_ctypes, + POINTER(mtmd_input_text), + POINTER(mtmd_bitmap_p_ctypes), + c_size_t, + ], + c_int, +) +def mtmd_tokenize( + ctx: mtmd_context_p, + output: mtmd_input_chunks_p, + text: "CtypesPointer[mtmd_input_text]", + bitmaps: "CtypesArray[mtmd_bitmap_p_ctypes]", + n_bitmaps: Union[c_size_t, int], + /, +) -> int: + ... + +@ctypes_function( + "mtmd_encode_chunk", + [mtmd_context_p_ctypes, mtmd_input_chunk_p_ctypes], + c_int, +) +def mtmd_encode_chunk(ctx: mtmd_context_p, chunk: mtmd_input_chunk_p, /) -> int: + ... + +@ctypes_function("mtmd_get_output_embd", [mtmd_context_p_ctypes], POINTER(c_float)) +def mtmd_get_output_embd( + ctx: mtmd_context_p, / +) -> "CtypesPointer[c_float]": + ... + +@ctypes_function("mtmd_decode_use_non_causal", [mtmd_context_p_ctypes], c_bool) +def mtmd_decode_use_non_causal(ctx: mtmd_context_p, /) -> bool: + ... + +@ctypes_function("mtmd_decode_use_mrope", [mtmd_context_p_ctypes], c_bool) +def mtmd_decode_use_mrope(ctx: mtmd_context_p, /) -> bool: + ... 
+ +@ctypes_function("mtmd_support_vision", [mtmd_context_p_ctypes], c_bool) +def mtmd_support_vision(ctx: mtmd_context_p, /) -> bool: + ... + +@ctypes_function("mtmd_support_audio", [mtmd_context_p_ctypes], c_bool) +def mtmd_support_audio(ctx: mtmd_context_p, /) -> bool: + ... + +@ctypes_function("mtmd_get_audio_bitrate", [mtmd_context_p_ctypes], c_int) +def mtmd_get_audio_bitrate(ctx: mtmd_context_p, /) -> int: + ... + +# mtmd_bitmap + +@ctypes_function( + "mtmd_bitmap_init", + [c_uint32, c_uint32, POINTER(c_uint8)], + mtmd_bitmap_p_ctypes, +) +def mtmd_bitmap_init( + nx: Union[c_uint32, int], + ny: Union[c_uint32, int], + data: "CtypesArray[c_uint8]", + /, +) -> Optional[mtmd_bitmap_p]: + ... + +@ctypes_function( + "mtmd_bitmap_init_from_audio", + [c_size_t, POINTER(c_float)], + mtmd_bitmap_p_ctypes, +) +def mtmd_bitmap_init_from_audio( + n_samples: Union[c_size_t, int], + data: "CtypesArray[c_float]", + /, +) -> Optional[mtmd_bitmap_p]: + ... + +@ctypes_function( + "mtmd_bitmap_get_nx", + [mtmd_bitmap_p_ctypes], + c_uint32, +) +def mtmd_bitmap_get_nx(bitmap: mtmd_bitmap_p, /) -> int: + ... + +@ctypes_function( + "mtmd_bitmap_get_ny", + [mtmd_bitmap_p_ctypes], + c_uint32, +) +def mtmd_bitmap_get_ny(bitmap: mtmd_bitmap_p, /) -> int: + ... + +@ctypes_function( + "mtmd_bitmap_get_data", + [mtmd_bitmap_p_ctypes], + POINTER(c_uint8), +) +def mtmd_bitmap_get_data( + bitmap: mtmd_bitmap_p, / +) -> "CtypesPointer[c_uint8]": + ... + +@ctypes_function( + "mtmd_bitmap_get_n_bytes", + [mtmd_bitmap_p_ctypes], + c_size_t, +) +def mtmd_bitmap_get_n_bytes(bitmap: mtmd_bitmap_p, /) -> int: + ... + +@ctypes_function( + "mtmd_bitmap_is_audio", + [mtmd_bitmap_p_ctypes], + c_bool, +) +def mtmd_bitmap_is_audio(bitmap: mtmd_bitmap_p, /) -> bool: + ... + +@ctypes_function( + "mtmd_bitmap_get_id", + [mtmd_bitmap_p_ctypes], + c_char_p, +) +def mtmd_bitmap_get_id(bitmap: mtmd_bitmap_p, /) -> bytes: + ... + +@ctypes_function( + "mtmd_bitmap_set_id", + [mtmd_bitmap_p_ctypes, c_char_p], + None, +) +def mtmd_bitmap_set_id(bitmap: mtmd_bitmap_p, id: bytes, /): + ... + +@ctypes_function( + "mtmd_bitmap_free", + [mtmd_bitmap_p_ctypes], + None, +) +def mtmd_bitmap_free(bitmap: mtmd_bitmap_p, /): + ... + +# mtmd_input_chunks + +@ctypes_function("mtmd_input_chunks_init", [], mtmd_input_chunks_p_ctypes) +def mtmd_input_chunks_init() -> Optional[mtmd_input_chunks_p]: + ... + +@ctypes_function("mtmd_input_chunks_size", [mtmd_input_chunks_p_ctypes], c_size_t) +def mtmd_input_chunks_size(chunks: mtmd_input_chunks_p, /) -> int: + ... + +@ctypes_function( + "mtmd_input_chunks_get", + [mtmd_input_chunks_p_ctypes, c_size_t], + mtmd_input_chunk_p_ctypes, +) +def mtmd_input_chunks_get( + chunks: mtmd_input_chunks_p, idx: Union[c_size_t, int], / +) -> Optional[mtmd_input_chunk_p]: + ... + +@ctypes_function("mtmd_input_chunks_free", [mtmd_input_chunks_p_ctypes], None) +def mtmd_input_chunks_free(chunks: mtmd_input_chunks_p, /): + ... + +# mtmd_input_chunk + +@ctypes_function( + "mtmd_input_chunk_get_type", [mtmd_input_chunk_p_ctypes], mtmd_input_chunk_type +) +def mtmd_input_chunk_get_type(chunk: mtmd_input_chunk_p, /) -> int: + ... + +@ctypes_function( + "mtmd_input_chunk_get_tokens_text", + [mtmd_input_chunk_p_ctypes, POINTER(c_size_t)], + POINTER(llama_cpp.llama_token), +) +def mtmd_input_chunk_get_tokens_text( + chunk: mtmd_input_chunk_p, n_tokens_output: "CtypesPointer[c_size_t]", / +) -> "CtypesPointer[llama_token]": + ... 
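+
+# Iteration sketch (illustrative only): walk the chunks produced by
+# mtmd_tokenize() and branch on the chunk type.
+#
+#   for i in range(mtmd_input_chunks_size(chunks)):
+#       chunk = mtmd_input_chunks_get(chunks, i)
+#       if mtmd_input_chunk_get_type(chunk) == MTMD_INPUT_CHUNK_TYPE_TEXT:
+#           n_tokens = c_size_t(0)
+#           tokens = mtmd_input_chunk_get_tokens_text(chunk, ctypes.byref(n_tokens))
+#       else:
+#           pass  # image/audio chunks are encoded with mtmd_encode_chunk()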
+ +@ctypes_function( + "mtmd_input_chunk_get_tokens_image", + [mtmd_input_chunk_p_ctypes], + mtmd_image_tokens_p_ctypes, +) +def mtmd_input_chunk_get_tokens_image( + chunk: mtmd_input_chunk_p, / +) -> Optional[mtmd_image_tokens_p]: + ... + +@ctypes_function( + "mtmd_input_chunk_get_n_tokens", [mtmd_input_chunk_p_ctypes], c_size_t +) +def mtmd_input_chunk_get_n_tokens(chunk: mtmd_input_chunk_p, /) -> int: + ... + +@ctypes_function( + "mtmd_input_chunk_get_n_pos", [mtmd_input_chunk_p_ctypes], llama_cpp.llama_pos +) +def mtmd_input_chunk_get_n_pos(chunk: mtmd_input_chunk_p, /) -> "llama_pos": + ... + +@ctypes_function("mtmd_input_chunk_get_id", [mtmd_input_chunk_p_ctypes], c_char_p) +def mtmd_input_chunk_get_id(chunk: mtmd_input_chunk_p, /) -> bytes: + ... + +@ctypes_function( + "mtmd_input_chunk_copy", [mtmd_input_chunk_p_ctypes], mtmd_input_chunk_p_ctypes +) +def mtmd_input_chunk_copy( + chunk: mtmd_input_chunk_p, / +) -> Optional[mtmd_input_chunk_p]: + ... + +@ctypes_function("mtmd_input_chunk_free", [mtmd_input_chunk_p_ctypes], None) +def mtmd_input_chunk_free(chunk: mtmd_input_chunk_p, /): + ... + +# mtmd_image_tokens + +@ctypes_function( + "mtmd_image_tokens_get_n_tokens", [mtmd_image_tokens_p_ctypes], c_size_t +) +def mtmd_image_tokens_get_n_tokens(image_tokens: mtmd_image_tokens_p, /) -> int: + ... + +@ctypes_function("mtmd_image_tokens_get_nx", [mtmd_image_tokens_p_ctypes], c_size_t) +def mtmd_image_tokens_get_nx(image_tokens: mtmd_image_tokens_p, /) -> int: + ... + +@ctypes_function("mtmd_image_tokens_get_ny", [mtmd_image_tokens_p_ctypes], c_size_t) +def mtmd_image_tokens_get_ny(image_tokens: mtmd_image_tokens_p, /) -> int: + ... + +@ctypes_function("mtmd_image_tokens_get_id", [mtmd_image_tokens_p_ctypes], c_char_p) +def mtmd_image_tokens_get_id(image_tokens: mtmd_image_tokens_p, /) -> bytes: + ... + +@ctypes_function( + "mtmd_image_tokens_get_n_pos", [mtmd_image_tokens_p_ctypes], llama_cpp.llama_pos +) +def mtmd_image_tokens_get_n_pos( + image_tokens: mtmd_image_tokens_p, / +) -> "llama_pos": + ... + +# New helper functions for bitmap handling +@ctypes_function( + "mtmd_helper_bitmap_init_from_file", + [mtmd_context_p_ctypes, c_char_p], + mtmd_bitmap_p_ctypes, +) +def mtmd_helper_bitmap_init_from_file( + ctx: mtmd_context_p, + fname: bytes, + /, +) -> Optional[mtmd_bitmap_p]: + ... + +@ctypes_function( + "mtmd_helper_eval_chunks", + [mtmd_context_p_ctypes, llama_cpp.llama_context_p_ctypes, mtmd_input_chunks_p_ctypes, llama_cpp.llama_pos, c_int32, c_int, c_bool, POINTER(llama_cpp.llama_pos)], + c_int, +) +def mtmd_helper_eval_chunks( + ctx: mtmd_context_p, + lctx: llama_cpp.llama_context_p, + chunks: mtmd_input_chunks_p, + n_past: llama_cpp.llama_pos, + seq_id: int, + n_batch: int, + logits_last: bool, + n_past_out: "CtypesPointer[llama_cpp.llama_pos]", + /, +) -> int: + ... 
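+
+# Note: mtmd_helper_eval_chunks() decodes every chunk into `lctx` and writes the
+# final KV-cache position through `n_past_out`; plain text generation should
+# resume from that position (see the end-to-end sketch after this file).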
+ +# Audio token structure +class mtmd_audio_tokens(Structure): + _fields_ = [ + ("n_tokens", c_uint32), + ("batch_f32", c_void_p), # clip_image_f32_batch + ("id", c_char_p), + ] + +mtmd_audio_tokens_p = NewType("mtmd_audio_tokens_p", int) +mtmd_audio_tokens_p_ctypes = c_void_p + +# Update mtmd_input_chunk to include audio tokens +class mtmd_input_chunk(Structure): + _fields_ = [ + ("type", mtmd_input_chunk_type), + ("tokens_text", POINTER(llama_cpp.llama_token)), + ("tokens_image", mtmd_image_tokens_p_ctypes), + ("tokens_audio", mtmd_audio_tokens_p_ctypes), + ] + +# Helper class for managing bitmaps +class BitmapManager: + def __init__(self): + self.entries: List[mtmd_bitmap_p] = [] + + def c_ptr(self) -> "CtypesArray[mtmd_bitmap_p_ctypes]": + arr_type = (mtmd_bitmap_p_ctypes * len(self.entries)) + return arr_type(*(entry for entry in self.entries)) + + def clear(self): + for bitmap in self.entries: + mtmd_bitmap_free(bitmap) + self.entries.clear() + + def add_from_memory(self, ctx: mtmd_context_p, data: bytes) -> bool: + import numpy as np + data_array = np.frombuffer(data, dtype=np.uint8) + bitmap = mtmd_helper_bitmap_init_from_buf(ctx, data_array.ctypes.data_as(POINTER(c_uint8)), len(data)) + if bitmap is None: + return False + self.entries.append(bitmap) + return True + + def __del__(self): + self.clear() + +# Helper class for managing input chunks +class InputChunksManager: + def __init__(self, chunks: mtmd_input_chunks_p): + self.ptr = chunks + + def __del__(self): + if self.ptr: + mtmd_input_chunks_free(self.ptr) + + def size(self) -> int: + return mtmd_input_chunks_size(self.ptr) + + def get(self, idx: int) -> Optional[mtmd_input_chunk_p]: + return mtmd_input_chunks_get(self.ptr, idx) + +@ctypes_function( + "mtmd_helper_get_n_tokens", + [mtmd_input_chunks_p_ctypes], + c_size_t, +) +def mtmd_helper_get_n_tokens(chunks: mtmd_input_chunks_p, /) -> int: + ... + +@ctypes_function( + "mtmd_helper_get_n_pos", + [mtmd_input_chunks_p_ctypes], + llama_cpp.llama_pos, +) +def mtmd_helper_get_n_pos(chunks: mtmd_input_chunks_p, /) -> "llama_pos": + ... + +@ctypes_function( + "mtmd_helper_bitmap_init_from_buf", + [mtmd_context_p_ctypes, POINTER(c_uint8), c_size_t], + mtmd_bitmap_p_ctypes, +) +def mtmd_helper_bitmap_init_from_buf( + ctx: mtmd_context_p, + buf: "CtypesArray[c_uint8]", + len: Union[c_size_t, int], + /, +) -> Optional[mtmd_bitmap_p]: + ... + +@ctypes_function( + "mtmd_helper_decode_image_chunk", + [mtmd_context_p_ctypes, llama_cpp.llama_context_p_ctypes, mtmd_input_chunk_p_ctypes, POINTER(c_float), llama_cpp.llama_pos, llama_cpp.llama_seq_id, c_int32, POINTER(llama_cpp.llama_pos)], + c_int32, +) +def mtmd_helper_decode_image_chunk( + ctx: mtmd_context_p, + lctx: llama_cpp.llama_context_p, + chunk: mtmd_input_chunk_p, + encoded_embd: "CtypesPointer[c_float]", + n_past: llama_cpp.llama_pos, + seq_id: llama_cpp.llama_seq_id, + n_batch: int, + new_n_past: "CtypesPointer[llama_cpp.llama_pos]", + /, +) -> int: + ... + +@ctypes_function( + "mtmd_helper_eval_chunk_single", + [mtmd_context_p_ctypes, llama_cpp.llama_context_p_ctypes, mtmd_input_chunk_p_ctypes, llama_cpp.llama_pos, llama_cpp.llama_seq_id, c_int32, c_bool, POINTER(llama_cpp.llama_pos)], + c_int32, +) +def mtmd_helper_eval_chunk_single( + ctx: mtmd_context_p, + lctx: llama_cpp.llama_context_p, + chunk: mtmd_input_chunk_p, + n_past: llama_cpp.llama_pos, + seq_id: llama_cpp.llama_seq_id, + n_batch: int, + logits_last: bool, + new_n_past: "CtypesPointer[llama_cpp.llama_pos]", + /, +) -> int: + ... 
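
Taken together, the bindings above are enough for an end-to-end multimodal prefill. The sketch below is illustrative rather than normative: the GGUF/mmproj/image paths are placeholders, the zero-return-means-success convention follows the upstream mtmd helpers, and the model/context loaders (llama_model_load_from_file, llama_init_from_model) come from the current llama.cpp C API rather than from this file.

import ctypes
import llama_cpp
import llama_cpp.mtmd_cpp as mtmd_cpp

MODEL_PATH = b"llava-v1.5-7b-Q8_0.gguf"               # placeholder
MMPROJ_PATH = b"llava-v1.5-7b-mmproj-model-f16.gguf"  # placeholder
IMAGE_PATH = b"tests/monalisa.jpg"                    # placeholder

llama_cpp.llama_backend_init()
model = llama_cpp.llama_model_load_from_file(MODEL_PATH, llama_cpp.llama_model_default_params())
lctx = llama_cpp.llama_init_from_model(model, llama_cpp.llama_context_default_params())

# Load the multimodal projector next to the text model.
mparams = mtmd_cpp.mtmd_context_params_default()
ctx = mtmd_cpp.mtmd_init_from_file(MMPROJ_PATH, model, mparams)
assert ctx is not None, "failed to load mmproj"

# Decode the image into a bitmap and tokenize a prompt containing the media marker.
bitmap = mtmd_cpp.mtmd_helper_bitmap_init_from_file(ctx, IMAGE_PATH)
assert bitmap is not None, "failed to load image"
marker = mtmd_cpp.mtmd_default_marker().decode()
prompt = f"USER: {marker}\nDescribe the image.\nASSISTANT:".encode()
text = mtmd_cpp.mtmd_input_text(text=prompt, add_special=True, parse_special=True)

chunks = mtmd_cpp.mtmd_input_chunks_init()
bitmaps = (mtmd_cpp.mtmd_bitmap_p_ctypes * 1)(bitmap)
assert mtmd_cpp.mtmd_tokenize(ctx, chunks, ctypes.byref(text), bitmaps, 1) == 0

# Evaluate every chunk (text and image) into the llama context's KV cache.
n_past = llama_cpp.llama_pos(0)
assert mtmd_cpp.mtmd_helper_eval_chunks(ctx, lctx, chunks, 0, 0, 512, True, ctypes.byref(n_past)) == 0
# From here, generation continues with the usual llama_decode()/sampler loop at position n_past.value.

mtmd_cpp.mtmd_input_chunks_free(chunks)
mtmd_cpp.mtmd_bitmap_free(bitmap)
mtmd_cpp.mtmd_free(ctx)

The BitmapManager and InputChunksManager helpers defined above wrap the same free calls for callers that prefer scoped cleanup over manual mtmd_*_free bookkeeping.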
diff --git a/llama_cpp/server/settings.py b/llama_cpp/server/settings.py index 13c9512..33b7a1d 100644 --- a/llama_cpp/server/settings.py +++ b/llama_cpp/server/settings.py @@ -172,11 +172,11 @@ class ModelSettings(BaseSettings): description="Number of tokens to predict using the draft model.", ) # KV Cache Quantization - type_k: Optional[int] = Field( + type_k: Optional[llama_cpp.GGMLType] = Field( default=None, description="Type of the key cache quantization.", ) - type_v: Optional[int] = Field( + type_v: Optional[llama_cpp.GGMLType] = Field( default=None, description="Type of the value cache quantization.", ) diff --git a/pyproject.toml b/pyproject.toml index 9983ef7..1f0aab5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,8 @@ test = [ "sse-starlette>=1.6.1", "starlette-context>=0.3.6,<0.4", "pydantic-settings>=2.0.1", - "huggingface-hub>=0.23.0" + "huggingface-hub>=0.23.0", + "typeguard>=4.2.1", ] dev = [ "black>=23.3.0", diff --git a/tests/monalisa.jpg b/tests/monalisa.jpg new file mode 100644 index 0000000..782ee4f Binary files /dev/null and b/tests/monalisa.jpg differ diff --git a/tests/test_llama_chat_format.py b/tests/test_llama_chat_format.py index f031bf7..42bbac1 100644 --- a/tests/test_llama_chat_format.py +++ b/tests/test_llama_chat_format.py @@ -1,14 +1,29 @@ import json +import os +import platform +from collections.abc import Iterator +from typing import cast +import pytest import jinja2 +from typeguard import ForwardRefPolicy, check_type from llama_cpp import ( ChatCompletionRequestUserMessage, + Llama, + llama_chat_format, + llama_supports_gpu_offload, + llama_types ) -import llama_cpp.llama_types as llama_types -import llama_cpp.llama_chat_format as llama_chat_format - from llama_cpp.llama_chat_format import hf_tokenizer_config_to_chat_formatter +from llama_cpp.llama_types import ( + ChatCompletionRequestMessage, + ChatCompletionTool, + ChatCompletionToolChoiceOption, + CreateChatCompletionResponse, + CreateChatCompletionStreamResponse, +) + def test_mistral_instruct(): chat_template = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}" @@ -87,3 +102,118 @@ def test_hf_tokenizer_config_str_to_chat_formatter(): ) assert chat_formatter_respoonse.prompt == ("[INST] Hello, world! [/INST]" "") + + +def is_accelerator_available() -> bool: + """Check if an accelerator is available.""" + return llama_supports_gpu_offload() or (os.cpu_count() or 1) >= 8 + + +@pytest.mark.parametrize( + "stream", + [ + pytest.param(True, id="stream=True"), + pytest.param(False, id="stream=False"), + ], +) +@pytest.mark.parametrize( + "tool_choice", + [ + pytest.param("none", id="tool_choice=none"), + pytest.param("auto", id="tool_choice=auto"), + pytest.param( + {"type": "function", "function": {"name": "get_weather"}}, id="tool_choice=fixed" + ), + ], +) +@pytest.mark.parametrize( + "user_prompt_expected_tool_calls", + [ + pytest.param( + ("Is 7 a prime number?", 0), + id="expected_tool_calls=0", + ), + pytest.param( + ("What's the weather like in Paris today?", 1), + id="expected_tool_calls=1", + ), + pytest.param( + ("What's the weather like in Paris today? 
What about New York?", 2), + id="expected_tool_calls=2", + ), + ], +) +@pytest.mark.parametrize( + "llm_repo_id", + [ + pytest.param("bartowski/Llama-3.2-3B-Instruct-GGUF", id="llama_3.2_3B"), + pytest.param( + "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", + id="llama_3.1_8B", + marks=pytest.mark.skipif( + not is_accelerator_available(), reason="Accelerator not available" + ), + ), + ], +) +@pytest.mark.skipif( + platform.system() == "Darwin" and (os.cpu_count() or 1) < 8, + reason="Insufficient resources on macOS", +) +def test_llama_cpp_python_tool_use( + llm_repo_id: str, + user_prompt_expected_tool_calls: tuple[str, int], + tool_choice: ChatCompletionToolChoiceOption, + stream: bool, +) -> None: + """Test the upgraded chatml-function-calling llama-cpp-python chat handler.""" + user_prompt, expected_tool_calls = user_prompt_expected_tool_calls + if isinstance(tool_choice, dict) and expected_tool_calls == 0: + pytest.skip("Nonsensical") + llm = Llama.from_pretrained( + repo_id=llm_repo_id, + filename="*Q4_K_M.gguf", + n_ctx=4096, + n_gpu_layers=-1, + verbose=False, + chat_format="chatml-function-calling", + ) + messages: list[ChatCompletionRequestMessage] = [{"role": "user", "content": user_prompt}] + tools: list[ChatCompletionTool] = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the weather for a location.", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string", "description": "A city name."}}, + }, + }, + } + ] + response = llm.create_chat_completion( + messages=messages, tools=tools, tool_choice=tool_choice, stream=stream + ) + if stream: + response = cast(Iterator[CreateChatCompletionStreamResponse], response) + num_tool_calls = 0 + for chunk in response: + check_type(chunk, CreateChatCompletionStreamResponse) + tool_calls = chunk["choices"][0]["delta"].get("tool_calls") + if isinstance(tool_calls, list): + num_tool_calls = max(tool_call["index"] for tool_call in tool_calls) + 1 + assert num_tool_calls == (expected_tool_calls if tool_choice != "none" else 0) + else: + response = cast(CreateChatCompletionResponse, response) + check_type( + response, CreateChatCompletionResponse, forward_ref_policy=ForwardRefPolicy.IGNORE + ) + if expected_tool_calls == 0 or tool_choice == "none": + assert response["choices"][0]["message"].get("tool_calls") is None + else: + assert len(response["choices"][0]["message"]["tool_calls"]) == expected_tool_calls + assert all( + tool_call["function"]["name"] == tools[0]["function"]["name"] + for tool_call in response["choices"][0]["message"]["tool_calls"] + ) diff --git a/tests/test_llava.py b/tests/test_llava.py new file mode 100644 index 0000000..2be6017 --- /dev/null +++ b/tests/test_llava.py @@ -0,0 +1,80 @@ +import multiprocessing +import ctypes + +from huggingface_hub import hf_hub_download + +import pytest + +import llama_cpp + +@pytest.fixture +def mmproj_model_path(): + repo_id = "second-state/Llava-v1.5-7B-GGUF" + filename = "llava-v1.5-7b-mmproj-model-f16.gguf" + model_path = hf_hub_download(repo_id, filename) + return model_path + +@pytest.fixture +def llava_cpp_model_path(): + repo_id = "second-state/Llava-v1.5-7B-GGUF" + filename = "llava-v1.5-7b-Q8_0.gguf" + model_path = hf_hub_download(repo_id, filename) + return model_path + +def test_real_llava(llava_cpp_model_path, mmproj_model_path): + print("initializing model") + model = llama_cpp.Llama( + llava_cpp_model_path, + n_ctx=2048, + n_batch=512, + n_threads=multiprocessing.cpu_count(), + 
n_threads_batch=multiprocessing.cpu_count(), + logits_all=False, + verbose=False, + ) + + # Initialize the LLaVA chat handler + from llama_cpp.llama_chat_format import Llava15ChatHandler + print("initializing chat handler") + chat_handler = Llava15ChatHandler(clip_model_path=mmproj_model_path, llama_model=model) + + # Create a chat message with the image + print("creating chat message") + messages = [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": "./tests/monalisa.jpg" + }, + { + "type": "text", + "text": "Do you know who drew this painting?" + } + ] + } + ] + + # Generate response + print("generating response") + response = chat_handler( + llama=model, + messages=messages, + max_tokens=200, + temperature=0.2, + top_p=0.95, + stream=False + ) + + print("response", response) + # Check that we got a response + assert response is not None + assert "choices" in response + assert len(response["choices"]) > 0 + assert "message" in response["choices"][0] + assert "content" in response["choices"][0]["message"] + + # The response should mention Leonardo da Vinci + content = response["choices"][0]["message"]["content"].lower() + assert "leonardo" in content and "vinci" in content # Artist name should be in response diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 7841fc7..f13fa9b 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 7841fc723e059d1fd9640e5c0ef19050fcc7c698 +Subproject commit f13fa9b2b523e22ba58fcf4c468f670d8c98d912
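
Finally, for downstream users of the low-level bindings touched earlier in this diff: the KV-cache maintenance entry points are now exposed under the llama_kv_self_* names, and a warmup toggle was added. A minimal sketch, assuming a GGUF model at a placeholder path and using the current llama.cpp loader entry points (llama_model_load_from_file, llama_init_from_model), which are not themselves part of this diff:

import llama_cpp

llama_cpp.llama_backend_init()
model = llama_cpp.llama_model_load_from_file(b"model.gguf", llama_cpp.llama_model_default_params())
ctx = llama_cpp.llama_init_from_model(model, llama_cpp.llama_context_default_params())

# New in this diff: with warmup enabled, the next llama_decode() activates every
# model tensor so the weights get loaded and cached; disable it again afterwards.
llama_cpp.llama_set_warmup(ctx, True)

# Renamed in this diff: llama_kv_cache_update/_can_shift are now llama_kv_self_update/_can_shift.
if llama_cpp.llama_kv_self_can_shift(ctx):
    llama_cpp.llama_kv_self_update(ctx)  # apply pending K-shifts / defragmentation

llama_cpp.llama_free(ctx)
llama_cpp.llama_model_free(model)
llama_cpp.llama_backend_free()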