Support truncate and normalize in InferenceClient (#2270)
* Support truncate and normalize in InferenceClient

* quality
Wauplin authored May 27, 2024
1 parent ce00270 commit 11e692a
Showing 3 changed files with 49 additions and 10 deletions.
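For reference, a minimal usage sketch of the new parameters. The model name is illustrative, and `normalize`/`truncate` only take effect on endpoints powered by Text-Embedding-Inference:

```python
from huggingface_hub import InferenceClient

# Illustrative TEI-backed model; any feature-extraction model served by
# Text-Embedding-Inference should accept the new parameters.
client = InferenceClient(model="thenlper/gte-large")

embedding = client.feature_extraction(
    "Hi, who are you?",
    normalize=True,   # normalize the returned embedding
    truncate=True,    # truncate inputs longer than the model's max length
)
print(embedding.shape, embedding.dtype)  # e.g. (1, 1024) float32
```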
22 changes: 20 additions & 2 deletions src/huggingface_hub/inference/_client.py
@@ -913,7 +913,14 @@ def document_question_answering(
         response = self.post(json=payload, model=model, task="document-question-answering")
         return DocumentQuestionAnsweringOutputElement.parse_obj_as_list(response)
 
-    def feature_extraction(self, text: str, *, model: Optional[str] = None) -> "np.ndarray":
+    def feature_extraction(
+        self,
+        text: str,
+        *,
+        normalize: Optional[bool] = None,
+        truncate: Optional[bool] = None,
+        model: Optional[str] = None,
+    ) -> "np.ndarray":
         """
         Generate embeddings for a given text.
 
@@ -924,6 +931,12 @@ def feature_extraction(self, text: str, *, model: Optional[str] = None) -> "np.ndarray":
                 The model to use for the conversational task. Can be a model ID hosted on the Hugging Face Hub or a URL to
                 a deployed Inference Endpoint. If not provided, the default recommended conversational model will be used.
                 Defaults to None.
+            normalize (`bool`, *optional*):
+                Whether to normalize the embeddings or not. Defaults to None.
+                Only available on servers powered by Text-Embedding-Inference.
+            truncate (`bool`, *optional*):
+                Whether to truncate the input text or not. Defaults to None.
+                Only available on servers powered by Text-Embedding-Inference.
 
         Returns:
             `np.ndarray`: The embedding representing the input text as a float32 numpy array.
@@ -945,7 +958,12 @@ def feature_extraction(self, text: str, *, model: Optional[str] = None) -> "np.ndarray":
                [ 0.28552425, -0.928395 , -1.2077185 , ..., 0.76810825, -2.1069427 , 0.6236161 ]], dtype=float32)
         ```
         """
-        response = self.post(json={"inputs": text}, model=model, task="feature-extraction")
+        payload: Dict = {"inputs": text}
+        if normalize is not None:
+            payload["normalize"] = normalize
+        if truncate is not None:
+            payload["truncate"] = truncate
+        response = self.post(json=payload, model=model, task="feature-extraction")
         np = _import_numpy()
         return np.array(_bytes_to_dict(response), dtype="float32")
 

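The client only forwards `normalize` and `truncate` when they are explicitly set, so the request body for existing calls is unchanged. A standalone sketch of that pattern (`build_payload` is a hypothetical helper, not part of the library):

```python
from typing import Any, Dict, Optional

def build_payload(
    text: str,
    normalize: Optional[bool] = None,
    truncate: Optional[bool] = None,
) -> Dict[str, Any]:
    # Mirrors the conditional payload construction in feature_extraction():
    # parameters left as None are omitted from the request body entirely.
    payload: Dict[str, Any] = {"inputs": text}
    if normalize is not None:
        payload["normalize"] = normalize
    if truncate is not None:
        payload["truncate"] = truncate
    return payload

print(build_payload("hello"))                  # {'inputs': 'hello'}
print(build_payload("hello", normalize=True))  # {'inputs': 'hello', 'normalize': True}
```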
22 changes: 20 additions & 2 deletions src/huggingface_hub/inference/_generated/_async_client.py
@@ -915,7 +915,14 @@ async def document_question_answering(
         response = await self.post(json=payload, model=model, task="document-question-answering")
         return DocumentQuestionAnsweringOutputElement.parse_obj_as_list(response)
 
-    async def feature_extraction(self, text: str, *, model: Optional[str] = None) -> "np.ndarray":
+    async def feature_extraction(
+        self,
+        text: str,
+        *,
+        normalize: Optional[bool] = None,
+        truncate: Optional[bool] = None,
+        model: Optional[str] = None,
+    ) -> "np.ndarray":
         """
         Generate embeddings for a given text.
 
@@ -926,6 +933,12 @@ async def feature_extraction(self, text: str, *, model: Optional[str] = None) -> "np.ndarray":
                 The model to use for the conversational task. Can be a model ID hosted on the Hugging Face Hub or a URL to
                 a deployed Inference Endpoint. If not provided, the default recommended conversational model will be used.
                 Defaults to None.
+            normalize (`bool`, *optional*):
+                Whether to normalize the embeddings or not. Defaults to None.
+                Only available on servers powered by Text-Embedding-Inference.
+            truncate (`bool`, *optional*):
+                Whether to truncate the input text or not. Defaults to None.
+                Only available on servers powered by Text-Embedding-Inference.
 
         Returns:
             `np.ndarray`: The embedding representing the input text as a float32 numpy array.
@@ -948,7 +961,12 @@ async def feature_extraction(self, text: str, *, model: Optional[str] = None) -> "np.ndarray":
                [ 0.28552425, -0.928395 , -1.2077185 , ..., 0.76810825, -2.1069427 , 0.6236161 ]], dtype=float32)
         ```
         """
-        response = await self.post(json={"inputs": text}, model=model, task="feature-extraction")
+        payload: Dict = {"inputs": text}
+        if normalize is not None:
+            payload["normalize"] = normalize
+        if truncate is not None:
+            payload["truncate"] = truncate
+        response = await self.post(json=payload, model=model, task="feature-extraction")
         np = _import_numpy()
         return np.array(_bytes_to_dict(response), dtype="float32")
 

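The async client mirrors the sync API. A minimal sketch, with the same caveats as above (illustrative model name, TEI-powered endpoint required for the new flags):

```python
import asyncio

from huggingface_hub import AsyncInferenceClient

async def main() -> None:
    client = AsyncInferenceClient(model="thenlper/gte-large")  # illustrative model
    embedding = await client.feature_extraction(
        "Today is a sunny day.",
        normalize=True,
        truncate=True,
    )
    print(embedding.shape)

asyncio.run(main())
```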
15 changes: 9 additions & 6 deletions src/huggingface_hub/inference/_generated/types/feature_extraction.py
@@ -4,16 +4,19 @@
 # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
 # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
 from dataclasses import dataclass
-from typing import Any, Dict, Optional
+from typing import List, Optional, Union
 
 from .base import BaseInferenceType
 
 
 @dataclass
 class FeatureExtractionInput(BaseInferenceType):
-    """Inputs for Text Embedding inference"""
+    """Feature Extraction Input.
+    Auto-generated from TEI specs.
+    For more details, check out
+    https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tei-import.ts.
+    """
 
-    inputs: str
-    """The text to get the embeddings of"""
-    parameters: Optional[Dict[str, Any]] = None
-    """Additional inference parameters"""
+    inputs: Union[List[str], str]
+    normalize: Optional[bool] = None
+    truncate: Optional[bool] = None
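The regenerated input type now accepts either a single string or a batch of strings, plus the two new flags. A small sketch, assuming the dataclass is importable from the generated-types package:

```python
from huggingface_hub.inference._generated.types import FeatureExtractionInput

# Single text with both new flags set explicitly.
single = FeatureExtractionInput(inputs="Hello world", normalize=True, truncate=False)

# Batched input; flags left as None simply mean "not specified".
batch = FeatureExtractionInput(inputs=["first sentence", "second sentence"])

print(single.normalize, batch.truncate)  # True None
```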
