Skip to content

Commit

Permalink
fix: Don't store scores internally unless logits_all=True. Reduces memory requirements for large context. Closes #1542
Browse files Browse the repository at this point in the history
  • Loading branch information
abetlen committed Sep 19, 2024
1 parent 22cedad commit 29afcfd
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions llama_cpp/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ def free_lora_adapter():
self.n_tokens = 0
self.input_ids: npt.NDArray[np.intc] = np.ndarray((n_ctx,), dtype=np.intc)
self.scores: npt.NDArray[np.single] = np.ndarray(
(n_ctx, self._n_vocab), dtype=np.single
(n_ctx if logits_all == True else n_batch, self._n_vocab), dtype=np.single
)

self._mirostat_mu = ctypes.c_float(
Expand Down Expand Up @@ -648,12 +648,14 @@ def eval(self, tokens: Sequence[int]):
)
self.scores[n_past : n_past + n_tokens, :].reshape(-1)[::] = logits
else:
rows = 1
cols = self._n_vocab
logits = np.ctypeslib.as_array(
self._ctx.get_logits(), shape=(rows * cols,)
)
self.scores[n_past + n_tokens - 1, :].reshape(-1)[::] = logits
# rows = 1
# cols = self._n_vocab
# logits = np.ctypeslib.as_array(
# self._ctx.get_logits(), shape=(rows * cols,)
# )
# self.scores[n_past + n_tokens - 1, :].reshape(-1)[::] = logits
# NOTE: Now that sampling is done inside the sampler, logits are only needed for logprobs which requires logits_all
pass
# Update n_tokens
self.n_tokens += n_tokens

Expand Down

0 comments on commit 29afcfd

Please sign in to comment.