Support /info and /health routes (#2269)
Wauplin authored May 27, 2024
1 parent 5a8cfcd commit 8873deb
Showing 3 changed files with 211 additions and 0 deletions.
89 changes: 89 additions & 0 deletions src/huggingface_hub/inference/_client.py
@@ -2547,6 +2547,95 @@ def get_recommended_model(task: str) -> str:
        )
        return model

    def get_endpoint_info(self, *, model: Optional[str] = None) -> Dict[str, Any]:
        """
        Get information about the deployed endpoint.

        This route is only available on endpoints powered by Text-Generation-Inference (TGI) or Text-Embedding-Inference (TEI).
        Endpoints powered by `transformers` return an empty payload.

        Args:
            model (`str`, *optional*):
                The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
                Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.

        Returns:
            `Dict[str, Any]`: Information about the endpoint.

        Example:
        ```py
        >>> from huggingface_hub import InferenceClient
        >>> client = InferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
        >>> client.get_endpoint_info()
        {
            'model_id': 'meta-llama/Meta-Llama-3-70B-Instruct',
            'model_sha': None,
            'model_dtype': 'torch.float16',
            'model_device_type': 'cuda',
            'model_pipeline_tag': None,
            'max_concurrent_requests': 128,
            'max_best_of': 2,
            'max_stop_sequences': 4,
            'max_input_length': 8191,
            'max_total_tokens': 8192,
            'waiting_served_ratio': 0.3,
            'max_batch_total_tokens': 1259392,
            'max_waiting_tokens': 20,
            'max_batch_size': None,
            'validation_workers': 32,
            'max_client_batch_size': 4,
            'version': '2.0.2',
            'sha': 'dccab72549635c7eb5ddb17f43f0b7cdff07c214',
            'docker_label': 'sha-dccab72'
        }
        ```
        """
        model = model or self.model
        if model is None:
            raise ValueError("Model id not provided.")
        if model.startswith(("http://", "https://")):
            url = model.rstrip("/") + "/info"
        else:
            url = f"{INFERENCE_ENDPOINT}/models/{model}/info"

        response = get_session().get(url, headers=self.headers)
        hf_raise_for_status(response)
        return response.json()
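For reference, the `/info` route can also be queried directly, since the method above is a thin wrapper over a single GET request. A minimal sketch with `requests`, mirroring the wrapper's logic; the endpoint URL is a hypothetical placeholder:

```py
# Sketch: fetch the TGI/TEI /info route directly, mirroring the wrapper above.
import requests

ENDPOINT_URL = "https://my-endpoint.example.com"  # hypothetical placeholder

response = requests.get(f"{ENDPOINT_URL.rstrip('/')}/info", timeout=10)
response.raise_for_status()
info = response.json()
print(info.get("model_id"), info.get("max_total_tokens"))
```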

    def health_check(self, model: Optional[str] = None) -> bool:
        """
        Check the health of the deployed endpoint.

        Health check is only available with Inference Endpoints powered by Text-Generation-Inference (TGI) or Text-Embedding-Inference (TEI).
        For the serverless Inference API, use [`InferenceClient.get_model_status`] instead.

        Args:
            model (`str`, *optional*):
                URL of the Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.

        Returns:
            `bool`: True if everything is working fine.

        Example:
        ```py
        >>> from huggingface_hub import InferenceClient
        >>> client = InferenceClient("https://jzgu0buei5.us-east-1.aws.endpoints.huggingface.cloud")
        >>> client.health_check()
        True
        ```
        """
        model = model or self.model
        if model is None:
            raise ValueError("Model id not provided.")
        if not model.startswith(("http://", "https://")):
            raise ValueError(
                "Model must be an Inference Endpoint URL. For serverless Inference API, please use `InferenceClient.get_model_status`."
            )
        url = model.rstrip("/") + "/health"

        response = get_session().get(url, headers=self.headers)
        return response.status_code == 200
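Because `health_check` returns a boolean rather than raising, it composes naturally into a readiness loop. A sketch that waits for a hypothetical endpoint to come up; the URL and timeout values are made up for illustration:

```py
# Sketch: poll health_check until the endpoint reports healthy or we time out.
import time

from huggingface_hub import InferenceClient

client = InferenceClient("https://my-endpoint.example.com")  # placeholder URL

deadline = time.monotonic() + 300  # arbitrary 5-minute budget
while time.monotonic() < deadline:
    if client.health_check():
        print("Endpoint is healthy.")
        break
    time.sleep(10)  # arbitrary back-off between probes
else:
    raise TimeoutError("Endpoint never became healthy.")
```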

    def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
        """
        Get the status of a model hosted on the Inference API.
93 changes: 93 additions & 0 deletions src/huggingface_hub/inference/_generated/_async_client.py
@@ -2579,6 +2579,99 @@ def get_recommended_model(task: str) -> str:
        )
        return model

    async def get_endpoint_info(self, *, model: Optional[str] = None) -> Dict[str, Any]:
        """
        Get information about the deployed endpoint.

        This route is only available on endpoints powered by Text-Generation-Inference (TGI) or Text-Embedding-Inference (TEI).
        Endpoints powered by `transformers` return an empty payload.

        Args:
            model (`str`, *optional*):
                The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
                Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.

        Returns:
            `Dict[str, Any]`: Information about the endpoint.

        Example:
        ```py
        # Must be run in an async context
        >>> from huggingface_hub import AsyncInferenceClient
        >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
        >>> await client.get_endpoint_info()
        {
            'model_id': 'meta-llama/Meta-Llama-3-70B-Instruct',
            'model_sha': None,
            'model_dtype': 'torch.float16',
            'model_device_type': 'cuda',
            'model_pipeline_tag': None,
            'max_concurrent_requests': 128,
            'max_best_of': 2,
            'max_stop_sequences': 4,
            'max_input_length': 8191,
            'max_total_tokens': 8192,
            'waiting_served_ratio': 0.3,
            'max_batch_total_tokens': 1259392,
            'max_waiting_tokens': 20,
            'max_batch_size': None,
            'validation_workers': 32,
            'max_client_batch_size': 4,
            'version': '2.0.2',
            'sha': 'dccab72549635c7eb5ddb17f43f0b7cdff07c214',
            'docker_label': 'sha-dccab72'
        }
        ```
        """
        model = model or self.model
        if model is None:
            raise ValueError("Model id not provided.")
        if model.startswith(("http://", "https://")):
            url = model.rstrip("/") + "/info"
        else:
            url = f"{INFERENCE_ENDPOINT}/models/{model}/info"

        async with _import_aiohttp().ClientSession(headers=self.headers) as client:
            response = await client.get(url)
            response.raise_for_status()
            return await response.json()
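One porting detail worth noting: the sync client funnels failures through `hf_raise_for_status`, while this async variant relies on aiohttp's own `raise_for_status`, so callers see `aiohttp.ClientResponseError` rather than huggingface_hub's HTTP error type. A hedged sketch of catching that, against a placeholder endpoint URL:

```py
# Sketch: the async client surfaces HTTP failures as aiohttp exceptions.
import asyncio

import aiohttp
from huggingface_hub import AsyncInferenceClient

async def main() -> None:
    client = AsyncInferenceClient("https://my-endpoint.example.com")  # placeholder
    try:
        info = await client.get_endpoint_info()
    except aiohttp.ClientResponseError as err:
        print(f"/info failed with status {err.status}")
    else:
        print(info)

asyncio.run(main())
```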

    async def health_check(self, model: Optional[str] = None) -> bool:
        """
        Check the health of the deployed endpoint.

        Health check is only available with Inference Endpoints powered by Text-Generation-Inference (TGI) or Text-Embedding-Inference (TEI).
        For the serverless Inference API, use [`InferenceClient.get_model_status`] instead.

        Args:
            model (`str`, *optional*):
                URL of the Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.

        Returns:
            `bool`: True if everything is working fine.

        Example:
        ```py
        # Must be run in an async context
        >>> from huggingface_hub import AsyncInferenceClient
        >>> client = AsyncInferenceClient("https://jzgu0buei5.us-east-1.aws.endpoints.huggingface.cloud")
        >>> await client.health_check()
        True
        ```
        """
        model = model or self.model
        if model is None:
            raise ValueError("Model id not provided.")
        if not model.startswith(("http://", "https://")):
            raise ValueError(
                "Model must be an Inference Endpoint URL. For serverless Inference API, please use `InferenceClient.get_model_status`."
            )
        url = model.rstrip("/") + "/health"

        async with _import_aiohttp().ClientSession(headers=self.headers) as client:
            response = await client.get(url)
            return response.status == 200
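Since the async variant doesn't block, it lends itself to probing several endpoints concurrently. A sketch assuming a list of hypothetical endpoint URLs:

```py
# Sketch: probe several (hypothetical) endpoints concurrently; health_check
# returns False rather than raising when an endpoint is unreachable or unhealthy.
import asyncio

from huggingface_hub import AsyncInferenceClient

ENDPOINTS = [
    "https://endpoint-a.example.com",  # placeholder URLs
    "https://endpoint-b.example.com",
]

async def main() -> None:
    clients = [AsyncInferenceClient(url) for url in ENDPOINTS]
    results = await asyncio.gather(*(c.health_check() for c in clients))
    for url, healthy in zip(ENDPOINTS, results):
        print(url, "OK" if healthy else "DOWN")

asyncio.run(main())
```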

    async def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
        """
        Get the status of a model hosted on the Inference API.
29 changes: 29 additions & 0 deletions utils/generate_async_inference_client.py
@@ -65,6 +65,9 @@ def generate_async_client_code(code: str) -> str:
    # Adapt list_deployed_models
    code = _adapt_list_deployed_models(code)

    # Adapt /info and /health endpoints
    code = _adapt_info_and_health_endpoints(code)

    return code


@@ -448,6 +451,32 @@ async def _fetch_framework(framework: str) -> None:
    return code.replace(sync_snippet, async_snippet)


def _adapt_info_and_health_endpoints(code: str) -> str:
    info_sync_snippet = """
        response = get_session().get(url, headers=self.headers)
        hf_raise_for_status(response)
        return response.json()"""

    info_async_snippet = """
        async with _import_aiohttp().ClientSession(headers=self.headers) as client:
            response = await client.get(url)
            response.raise_for_status()
            return await response.json()"""

    code = code.replace(info_sync_snippet, info_async_snippet)

    health_sync_snippet = """
        response = get_session().get(url, headers=self.headers)
        return response.status_code == 200"""

    health_async_snippet = """
        async with _import_aiohttp().ClientSession(headers=self.headers) as client:
            response = await client.get(url)
            return response.status == 200"""

    return code.replace(health_sync_snippet, health_async_snippet)
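The adapter relies on verbatim string matching: the exact sync body from `_client.py` is swapped for its aiohttp equivalent, so any reformatting of those lines in the sync client would silently turn the replacement into a no-op. A quick way to sanity-check it, using a toy input with the same 8-space indentation as the real snippet:

```py
# Sketch: feed a toy sync body through the adapter and confirm the swap fires.
# Indentation must match the target snippet exactly (8 spaces) because the
# adapter matches strings verbatim.
toy_body = """
        response = get_session().get(url, headers=self.headers)
        hf_raise_for_status(response)
        return response.json()"""

adapted = _adapt_info_and_health_endpoints(toy_body)
assert "aiohttp" in adapted and "await client.get(url)" in adapted
print(adapted)
```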


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
