improve client.zero_shot_classification() #2340

Merged
Changes from all commits
33 commits
96173e6
add hypothesis template like in https://huggingface.co/docs/transform…
MoritzLaurer Jun 14, 2024
884a3d9
test removal of ",".join and min 2 class requirement
MoritzLaurer Jun 14, 2024
040a18b
remove references to min 2 labels
MoritzLaurer Jun 14, 2024
73506af
improved doc string and added example
MoritzLaurer Jun 14, 2024
9a05817
typo fix
MoritzLaurer Jun 14, 2024
9c4b1bf
remove white spaces in doc string
MoritzLaurer Jun 14, 2024
c554c61
suggestions implemented
MoritzLaurer Jun 17, 2024
d75f348
applied make style
MoritzLaurer Jun 17, 2024
af82918
update asyncinferenceclient
MoritzLaurer Jun 17, 2024
4f6398b
rerun python utils/generate_async_inference_client.py --update withou…
MoritzLaurer Jun 17, 2024
656a997
print actual error message when failing to load a submodule (#2342)
kallewoof Jun 17, 2024
22a5621
🌐 [i18n-KO] Translated `package_reference/environment_variables.md` t…
jungnerd Jun 17, 2024
8f98cec
Do not raise on `.resume()` if Inference Endpoint is already running …
Wauplin Jun 17, 2024
4142901
🌐 [i18n-KO] Translated `package_reference/webhooks_server.md` to Kore…
fabxoe Jun 17, 2024
e9d1acd
Support `resource_group_id` in `create_repo` (#2324)
Wauplin Jun 24, 2024
ae4fed1
[i18n-FR] Translated "Integrations" to french (sub PR of 1900) (#2329)
JibrilEl Jun 26, 2024
c282c95
Update _errors.py (#2354)
qgallouedec Jun 26, 2024
7832db5
[HfFileSystem] Faster `fs.walk()` (#2346)
lhoestq Jun 26, 2024
d298397
fix docs
Wauplin Jun 27, 2024
d7def91
Add proxy support on async client (#2350)
noech373 Jun 27, 2024
0c262f7
add a diagram about hf:// URLs (#2358)
severo Jun 28, 2024
4ba21f4
🌐 [i18n-KO] Translated `guides/manage-cache.md` to Korean (#2347)
cjfghk5697 Jul 1, 2024
85ffd0e
Update ruff in CI (#2365)
Wauplin Jul 1, 2024
11d4257
Promote chat_completion in inference guide (#2366)
Wauplin Jul 2, 2024
00dc599
Add `prompt_name` to feature-extraction + update types (#2363)
Wauplin Jul 2, 2024
9a1aa57
put model arg at the end
MoritzLaurer Jul 3, 2024
59fd9b0
added doc strings for hypothesis_template and made it optional
MoritzLaurer Jul 3, 2024
37577b1
ran make style
MoritzLaurer Jul 3, 2024
e76105d
ran python utils/generate_async_inference_client.py --update
MoritzLaurer Jul 3, 2024
0f45cad
hypothesis_template default to None
MoritzLaurer Jul 3, 2024
b1d59f5
make style and update async_client
MoritzLaurer Jul 3, 2024
8fec267
Merge branch 'main' into zeroshot-classification-client-fix
Wauplin Jul 3, 2024
f479f08
fix generate_async_inference_client
Wauplin Jul 3, 2024
52 changes: 40 additions & 12 deletions src/huggingface_hub/inference/_client.py
@@ -2412,7 +2412,13 @@ def visual_question_answering(
return VisualQuestionAnsweringOutputElement.parse_obj_as_list(response)

def zero_shot_classification(
-        self, text: str, labels: List[str], *, multi_label: bool = False, model: Optional[str] = None
+        self,
+        text: str,
+        labels: List[str],
+        *,
+        multi_label: bool = False,
+        hypothesis_template: Optional[str] = None,
+        model: Optional[str] = None,
) -> List[ZeroShotClassificationOutputElement]:
"""
Provide as input a text and a set of candidate labels to classify the input text.
@@ -2421,9 +2427,15 @@ def zero_shot_classification(
text (`str`):
The input text to classify.
labels (`List[str]`):
-                List of string possible labels. There must be at least 2 labels.
+                List of strings. Each string is the verbalization of a possible label for the input text.
multi_label (`bool`):
-                Boolean that is set to True if classes can overlap.
+                Boolean. If True, the probability for each label is evaluated independently and multiple labels can have a probability close to 1 simultaneously or all probabilities can be close to 0.
+                If False, the labels are considered mutually exclusive and the probability over all labels always sums to 1. Defaults to False.
+            hypothesis_template (`str`, *optional*):
+                A template sentence string with curly brackets to which the label strings are added. The label strings are added at the position of the curly brackets "{}".
+                Zero-shot classifiers are based on NLI models, which evaluate if a hypothesis is entailed in another text or not.
+                For example, with hypothesis_template="This text is about {}." and labels=["economics", "politics"], the system internally creates the two hypotheses "This text is about economics." and "This text is about politics.".
+                The model then evaluates for both hypotheses if they are entailed in the provided `text` or not.
model (`str`, *optional*):
The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
@@ -2437,7 +2449,7 @@ def zero_shot_classification(
`HTTPError`:
If the request fails with an HTTP error status code other than HTTP 503.

-        Example:
+        Example with `multi_label=False`:
```py
>>> from huggingface_hub import InferenceClient
>>> client = InferenceClient()
@@ -2464,21 +2476,37 @@ def zero_shot_classification(
ZeroShotClassificationOutputElement(label='robots', score=0.00030448526376858354),
]
```

+        Example with `multi_label=True` and a custom `hypothesis_template`:
+        ```py
+        >>> from huggingface_hub import InferenceClient
+        >>> client = InferenceClient()
+        >>> client.zero_shot_classification(
+        ...    text="I really like our dinner and I'm very happy. I don't like the weather though.",
+        ...    labels=["positive", "negative", "pessimistic", "optimistic"],
+        ...    multi_label=True,
+        ...    hypothesis_template="This text is {} towards the weather"
+        ... )
+        [
+            ZeroShotClassificationOutputElement(label='negative', score=0.9231801629066467),
+            ZeroShotClassificationOutputElement(label='pessimistic', score=0.8760990500450134),
+            ZeroShotClassificationOutputElement(label='optimistic', score=0.0008674879791215062),
+            ZeroShotClassificationOutputElement(label='positive', score=0.0005250611575320363)
+        ]
+        ```
"""
-        # Raise ValueError if input is less than 2 labels
-        if len(labels) < 2:
-            raise ValueError("You must specify at least 2 classes to compare.")

+        parameters = {"candidate_labels": labels, "multi_label": multi_label}
+        if hypothesis_template is not None:
+            parameters["hypothesis_template"] = hypothesis_template

        response = self.post(
            json={
                "inputs": text,
-                "parameters": {
-                    "candidate_labels": ",".join(labels),
-                    "multi_label": multi_label,
-                },
+                "parameters": parameters,
},
-            model=model,
            task="zero-shot-classification",
+            model=model,
)
output = _bytes_to_dict(response)
return [
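One motivation for sending `candidate_labels` as a plain list instead of `",".join(labels)` is that a label containing a comma gets corrupted when the joined string is split again on the server side. A minimal sketch of the difference (the payload shapes are illustrative, not the exact wire format):

```python
# Sketch: why a list is safer than ",".join(labels) for candidate_labels.
labels = ["politics", "economics, finance"]

old_payload = {"candidate_labels": ",".join(labels)}  # old client behavior
new_payload = {"candidate_labels": labels}            # behavior after this PR

# Re-splitting the joined string yields three labels instead of two,
# silently breaking the label that contained a comma:
resplit = old_payload["candidate_labels"].split(",")
print(resplit)                           # ['politics', 'economics', ' finance']
print(new_payload["candidate_labels"])   # ['politics', 'economics, finance']
```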
63 changes: 46 additions & 17 deletions src/huggingface_hub/inference/_generated/_async_client.py
@@ -633,7 +633,7 @@ async def chat_completion(
>>> messages = [
... {
... "role": "system",
-        ...     "content": "Don't make assumptions about what values to plug into functions. Ask async for clarification if a user request is ambiguous.",
+        ...     "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.",
... },
... {
... "role": "user",
@@ -877,7 +877,7 @@ async def conversational(
>>> client = AsyncInferenceClient()
>>> output = await client.conversational("Hi, who are you?")
>>> output
-        {'generated_text': 'I am the one who knocks.', 'conversation': {'generated_responses': ['I am the one who knocks.'], 'past_user_inputs': ['Hi, who are you?']}, 'warnings': ['Setting `pad_token_id` to `eos_token_id`:50256 async for open-end generation.']}
+        {'generated_text': 'I am the one who knocks.', 'conversation': {'generated_responses': ['I am the one who knocks.'], 'past_user_inputs': ['Hi, who are you?']}, 'warnings': ['Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.']}
>>> await client.conversational(
... "Wow, that's scary!",
... generated_responses=output["conversation"]["generated_responses"],
@@ -1960,7 +1960,7 @@ async def text_generation(
>>> await client.text_generation("The huggingface_hub library is ", max_new_tokens=12)
'100% open source and built to be easy to use.'

-        # Case 2: iterate over the generated tokens. Useful async for large generation.
+        # Case 2: iterate over the generated tokens. Useful for large generation.
>>> async for token in await client.text_generation("The huggingface_hub library is ", max_new_tokens=12, stream=True):
... print(token)
100
@@ -2444,7 +2444,13 @@ async def visual_question_answering(
return VisualQuestionAnsweringOutputElement.parse_obj_as_list(response)

async def zero_shot_classification(
-        self, text: str, labels: List[str], *, multi_label: bool = False, model: Optional[str] = None
+        self,
+        text: str,
+        labels: List[str],
+        *,
+        multi_label: bool = False,
+        hypothesis_template: Optional[str] = None,
+        model: Optional[str] = None,
) -> List[ZeroShotClassificationOutputElement]:
"""
Provide as input a text and a set of candidate labels to classify the input text.
@@ -2453,9 +2459,15 @@ async def zero_shot_classification(
text (`str`):
The input text to classify.
labels (`List[str]`):
-                List of string possible labels. There must be at least 2 labels.
+                List of strings. Each string is the verbalization of a possible label for the input text.
multi_label (`bool`):
-                Boolean that is set to True if classes can overlap.
+                Boolean. If True, the probability for each label is evaluated independently and multiple labels can have a probability close to 1 simultaneously or all probabilities can be close to 0.
+                If False, the labels are considered mutually exclusive and the probability over all labels always sums to 1. Defaults to False.
+            hypothesis_template (`str`, *optional*):
+                A template sentence string with curly brackets to which the label strings are added. The label strings are added at the position of the curly brackets "{}".
+                Zero-shot classifiers are based on NLI models, which evaluate if a hypothesis is entailed in another text or not.
+                For example, with hypothesis_template="This text is about {}." and labels=["economics", "politics"], the system internally creates the two hypotheses "This text is about economics." and "This text is about politics.".
+                The model then evaluates for both hypotheses if they are entailed in the provided `text` or not.
model (`str`, *optional*):
The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
@@ -2469,15 +2481,15 @@ async def zero_shot_classification(
`aiohttp.ClientResponseError`:
If the request fails with an HTTP error status code other than HTTP 503.

-        Example:
+        Example with `multi_label=False`:
```py
# Must be run in an async context
>>> from huggingface_hub import AsyncInferenceClient
>>> client = AsyncInferenceClient()
>>> text = (
-        ...     "A new model offers an explanation async for how the Galilean satellites formed around the solar system's"
+        ...     "A new model offers an explanation for how the Galilean satellites formed around the solar system's"
        ...     "largest world. Konstantin Batygin did not set out to solve one of the solar system's most puzzling"
-        ...     " mysteries when he went async for a run up a hill in Nice, France."
+        ...     " mysteries when he went for a run up a hill in Nice, France."
... )
>>> labels = ["space & cosmos", "scientific discovery", "microbiology", "robots", "archeology"]
>>> await client.zero_shot_classification(text, labels)
@@ -2497,21 +2509,38 @@ async def zero_shot_classification(
ZeroShotClassificationOutputElement(label='robots', score=0.00030448526376858354),
]
```

+        Example with `multi_label=True` and a custom `hypothesis_template`:
+        ```py
+        # Must be run in an async context
+        >>> from huggingface_hub import AsyncInferenceClient
+        >>> client = AsyncInferenceClient()
+        >>> await client.zero_shot_classification(
+        ...    text="I really like our dinner and I'm very happy. I don't like the weather though.",
+        ...    labels=["positive", "negative", "pessimistic", "optimistic"],
+        ...    multi_label=True,
+        ...    hypothesis_template="This text is {} towards the weather"
+        ... )
+        [
+            ZeroShotClassificationOutputElement(label='negative', score=0.9231801629066467),
+            ZeroShotClassificationOutputElement(label='pessimistic', score=0.8760990500450134),
+            ZeroShotClassificationOutputElement(label='optimistic', score=0.0008674879791215062),
+            ZeroShotClassificationOutputElement(label='positive', score=0.0005250611575320363)
+        ]
+        ```
"""
-        # Raise ValueError if input is less than 2 labels
-        if len(labels) < 2:
-            raise ValueError("You must specify at least 2 classes to compare.")

+        parameters = {"candidate_labels": labels, "multi_label": multi_label}
+        if hypothesis_template is not None:
+            parameters["hypothesis_template"] = hypothesis_template

        response = await self.post(
            json={
                "inputs": text,
-                "parameters": {
-                    "candidate_labels": ",".join(labels),
-                    "multi_label": multi_label,
-                },
+                "parameters": parameters,
},
-            model=model,
            task="zero-shot-classification",
+            model=model,
)
output = _bytes_to_dict(response)
return [
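As the new docstring describes, the hypothesis template is expanded into one NLI hypothesis per label. A rough sketch of that expansion (`expand_hypotheses` is a hypothetical helper for illustration; the actual substitution happens server-side in the inference pipeline, not in this client):

```python
# Sketch: how hypothesis_template plus labels becomes per-label NLI hypotheses.
def expand_hypotheses(hypothesis_template: str, labels: list) -> list:
    # Each label is substituted at the "{}" placeholder in the template.
    return [hypothesis_template.format(label) for label in labels]

hypotheses = expand_hypotheses("This text is about {}.", ["economics", "politics"])
print(hypotheses)
# ['This text is about economics.', 'This text is about politics.']
```

The model then scores, for each expanded hypothesis, whether it is entailed by the input `text`; with `multi_label=False` those scores are normalized across labels, with `multi_label=True` each is judged independently.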
4 changes: 2 additions & 2 deletions utils/generate_async_inference_client.py
@@ -377,15 +377,15 @@ def _update_example_code_block(code_block: str) -> str:
code_block = "\n # Must be run in an async context" + code_block
code_block = code_block.replace("InferenceClient", "AsyncInferenceClient")
code_block = code_block.replace("client.", "await client.")
-    code_block = code_block.replace(" for ", " async for ")
+    code_block = code_block.replace(">>> for ", ">>> async for ")
return code_block


def _update_examples_in_public_methods(code: str) -> str:
for match in re.finditer(
r"""
\n\s*
-        Example:\n\s*  # example section
+        Example.*?:\n\s*  # example section
```py # start
(.*?) # code block
``` # end
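The old blanket replacement rewrote every `" for "` in example code blocks to `" async for "`, which also mangled ordinary prose such as "Ask for clarification" and "Useful for large generation". Narrowing the pattern to the doctest prompt `">>> for "` converts real loops only. A small demonstration of the two behaviors:

```python
# Sketch: over-broad vs. narrowed replacement in the async-client generator.
example = (
    '>>> for token in client.text_generation("hi", stream=True):\n'
    "...     print(token)  # Useful for large generation.\n"
)

broad = example.replace(" for ", " async for ")        # old behavior
narrow = example.replace(">>> for ", ">>> async for ")  # fixed behavior

# The broad version corrupts the comment text as well as converting the loop:
print("async for large generation" in broad)    # True  -> prose mangled
print("async for large generation" in narrow)   # False -> prose intact
print(narrow.startswith(">>> async for "))      # True  -> loop still converted
```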