Skip to content

Commit

Permalink
rename host/url to uri; fetch to status
Browse files Browse the repository at this point in the history
  • Loading branch information
leondz committed Sep 23, 2024
1 parent e2fbd31 commit 2b9f682
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 15 deletions.
2 changes: 1 addition & 1 deletion docs/source/garak.generators.nemo.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ Configurable values:
* beam_width: 1
* length_penalty: 1
* guardrail: None - (present in API but not implemented in library)
* api_host: "https://api.llm.ngc.nvidia.com/v1" - endpoint URI
* api_uri: "https://api.llm.ngc.nvidia.com/v1" - endpoint URI



Expand Down
8 changes: 4 additions & 4 deletions docs/source/garak.generators.nvcf.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ and flexible generation.

NVCF functions work by sending a request to an invocation endpoint, and then polling
a status endpoint until the response is received. The cloud function is described
using a UUID, which is passed to garak as the model_name. API key should be placed in
environment variable NVCF_API_KEY or set in a garak config. For example:
using a UUID, which is passed to garak as the ``model_name``. API key should be placed in
environment variable ``NVCF_API_KEY`` or set in a garak config. For example:

.. code-block::
Expand All @@ -22,8 +22,8 @@ Configurable values:

* temperature - Temperature for generation. Passed as a value to the endpoint.
* top_p - Number of tokens to sample. Passed as a value to the endpoint.
* invoke_url_base - Base URL for the NVCF endpoint (default is for NVIDIA-hosted functions).
* fetch_url_format - URL to check for request status updates (default is for NVIDIA-hosted functions).
* invoke_uri_base - Base URL for the NVCF endpoint (default is for NVIDIA-hosted functions).
* status_uri_base - URL to check for request status updates (default is for NVIDIA-hosted functions).
* timeout - Read timeout for HTTP requests (note, this is network timeout, distinct from inference timeout)
* version_id - API version id, postpended to endpoint URLs if supplied
* stop_on_404 - Give up on endpoints returning 404 (i.e. nonexistent ones)
Expand Down
4 changes: 2 additions & 2 deletions garak/generators/nemo.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class NeMoGenerator(Generator):
"beam_width": 1,
"length_penalty": 1,
"guardrail": None, # NotImplemented in library
"api_host": "https://api.llm.ngc.nvidia.com/v1",
"api_uri": "https://api.llm.ngc.nvidia.com/v1",
}

supports_multiple_generations = False
Expand All @@ -48,7 +48,7 @@ def __init__(self, name=None, config_root=_config):
super().__init__(self.name, config_root=config_root)

self.nemo = nemollm.api.NemoLLM(
api_host=self.api_host, api_key=self.api_key, org_id=self.org_id
api_host=self.api_uri, api_key=self.api_key, org_id=self.org_id
)

if self.name is None:
Expand Down
14 changes: 7 additions & 7 deletions garak/generators/nvcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ class NvcfChat(Generator):
DEFAULT_PARAMS = Generator.DEFAULT_PARAMS | {
"temperature": 0.2,
"top_p": 0.7,
"fetch_url_format": "https://api.nvcf.nvidia.com/v2/nvcf/pexec/status/",
"invoke_url_base": "https://api.nvcf.nvidia.com/v2/nvcf/pexec/functions/",
"status_uri_base": "https://api.nvcf.nvidia.com/v2/nvcf/pexec/status/",
"invoke_uri_base": "https://api.nvcf.nvidia.com/v2/nvcf/pexec/functions/",
"timeout": 60,
"version_id": None, # string
"stop_on_404": True,
Expand All @@ -49,10 +49,10 @@ def __init__(self, name=None, config_root=_config):
"Please specify a function identifier in model name (-n)"
)

self.invoke_url = self.invoke_url_base + self.name
self.invoke_uri = self.invoke_uri_base + self.name

if self.version_id is not None:
self.invoke_url += f"/versions/{self.version_id}"
self.invoke_uri += f"/versions/{self.version_id}"

super().__init__(self.name, config_root=config_root)

Expand Down Expand Up @@ -109,7 +109,7 @@ def _call_model(

request_time = time.time()
logging.debug("nvcf : payload %s", repr(payload))
response = session.post(self.invoke_url, headers=self.headers, json=payload)
response = session.post(self.invoke_uri, headers=self.headers, json=payload)

while response.status_code == 202:
if time.time() > request_time + self.timeout:
Expand All @@ -119,8 +119,8 @@ def _call_model(
msg = "Got HTTP 202 but no NVCF-REQID was returned"
logging.info("nvcf : %s", msg)
raise AttributeError(msg)
fetch_url = self.fetch_url_format + request_id
response = session.get(fetch_url, headers=self.headers)
status_uri = self.status_uri_base + request_id
response = session.get(status_uri, headers=self.headers)

if 400 <= response.status_code < 600:
logging.warning("nvcf : returned error code %s", response.status_code)
Expand Down
2 changes: 1 addition & 1 deletion tests/generators/test_nvcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_version_endpoint(klassname):
_config.plugins.generators["nvcf"][klassname]["api_key"] = "placeholder key"
_config.plugins.generators["nvcf"][klassname]["version_id"] = version
g = _plugins.load_plugin(f"generators.nvcf.{klassname}")
assert g.invoke_url == f"{g.invoke_url_base}{name}/versions/{version}"
assert g.invoke_uri == f"{g.invoke_uri_base}{name}/versions/{version}"


@pytest.mark.parametrize("klassname", PLUGINS)
Expand Down

0 comments on commit 2b9f682

Please sign in to comment.