diff --git a/replicate/training.py b/replicate/training.py index 1f20660..8125cdf 100644 --- a/replicate/training.py +++ b/replicate/training.py @@ -418,4 +418,14 @@ def _create_training_url_from_model_and_version( def _json_to_training(client: "Client", json: Dict[str, Any]) -> Training: training = Training(**json) training._client = client + + # FIXME: This should be populated by the API + if ( + training.output + and isinstance(training.output, dict) + and "version" in training.output + ): + id = ModelVersionIdentifier.parse(training.output["version"]) + training.destination = f"{id.owner}/{id.name}" + return training diff --git a/tests/test_training.py b/tests/test_training.py index 64926c6..930ca6a 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -1,4 +1,6 @@ +import httpx import pytest +import respx import replicate from replicate.exceptions import ReplicateException @@ -188,3 +190,54 @@ async def test_trainings_cancel_instance_method(async_flag, mock_replicate_api_t training.cancel() assert training.status == "canceled" + + +router = respx.Router(base_url="https://api.replicate.com/v1") + +router.route( + method="GET", + path="/trainings/zz4ibbonubfz7carwiefibzgga", + name="trainings.get", +).mock( + return_value=httpx.Response( + 201, + json={ + "completed_at": "2023-09-08T16:41:19.826523Z", + "created_at": "2023-09-08T16:32:57.018467Z", + "error": None, + "id": "zz4ibbonubfz7carwiefibzgga", + "input": {"input_images": "https://example.com/my-input-images.zip"}, + "logs": "...", + "metrics": {"predict_time": 502.713876}, + "output": { + "version": "replicate/my-app-image-generator:8a43525956ef4039702e509c789964a7ea873697be9033abf9fd2badfe68c9e3", + "weights": "https://weights.replicate.com/example.tar", + }, + "started_at": "2023-09-08T16:32:57.112647Z", + "status": "succeeded", + "urls": { + "get": "https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga", + "cancel": "https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga/cancel", + }, + "model": "stability-ai/sdxl", + "version": "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf", + }, + ) +) + +router.route(host="api.replicate.com").pass_through() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("async_flag", [True, False]) +async def test_training_gets_destination_from_output(async_flag): + client = replicate.Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + + if async_flag: + training = await client.trainings.async_get("zz4ibbonubfz7carwiefibzgga") + else: + training = client.trainings.get("zz4ibbonubfz7carwiefibzgga") + + assert training.destination == "replicate/my-app-image-generator"