Skip to content

Commit 08b4781

Browse files
Updating translate to accept both binary data and params with overloading.
1 parent 5ebe84d commit 08b4781

File tree

2 files changed

+66
-34
lines changed

2 files changed

+66
-34
lines changed

jigsawstack/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def __init__(
100100
api_key=api_key,
101101
api_url=api_url,
102102
disable_request_logging=disable_request_logging,
103-
).translate
103+
)
104104

105105
self.prompt_engine = PromptEngine(
106106
api_key=api_key,
@@ -209,7 +209,7 @@ def __init__(
209209
api_key=api_key,
210210
api_url=api_url,
211211
disable_request_logging=disable_request_logging,
212-
).translate
212+
)
213213

214214
self.prompt_engine = AsyncPromptEngine(
215215
api_key=api_key,

jigsawstack/translate.py

Lines changed: 64 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
1-
from typing import Any, Dict, List, Union, cast
1+
from typing import Any, Dict, List, Union, cast, overload
22
from typing_extensions import NotRequired, TypedDict
33
from .request import Request, RequestConfig
44
from .async_request import AsyncRequest
55
from typing import List, Union
66
from ._config import ClientConfig
7-
7+
from .helpers import build_path
88

99
class TranslateImageParams(TypedDict):
1010
target_language: str
1111
"""
1212
Target langauge to translate to.
1313
"""
14-
url: str
14+
url: NotRequired[str]
1515
"""
1616
The URL of the image to translate.
1717
"""
@@ -82,7 +82,7 @@ def __init__(
8282
disable_request_logging=disable_request_logging,
8383
)
8484

85-
def translate_text(
85+
def text(
8686
self, params: TranslateParams
8787
) -> Union[TranslateResponse, TranslateListResponse]:
8888
resp = Request(
@@ -92,25 +92,41 @@ def translate_text(
9292
verb="post",
9393
).perform()
9494
return resp
95+
96+
@overload
97+
def image(self, params: TranslateImageParams) -> TranslateImageResponse: ...
98+
@overload
99+
def image(self, file: bytes, options: TranslateImageParams = None) -> TranslateImageParams: ...
100+
101+
def image(
102+
self,
103+
blob: Union[TranslateImageParams, bytes],
104+
options: TranslateImageParams = None,
105+
) -> TranslateImageResponse:
106+
if isinstance(blob, dict): # If params is provided as a dict, we assume it's the first argument
107+
resp = Request(
108+
config=self.config,
109+
path="/ai/translate/image",
110+
params=cast(Dict[Any, Any], blob),
111+
verb="post",
112+
).perform_with_content()
113+
return resp
114+
115+
options = options or {}
116+
path = build_path(base_path="/ai/translate/image", params=options)
117+
content_type = options.get("content_type", "application/octet-stream")
118+
headers = {"Content-Type": content_type}
95119

96-
def translate_image(
97-
self, params: TranslateImageParams
98-
) -> TranslateImageResponse:
99120
resp = Request(
100121
config=self.config,
101-
path="/ai/translate/image",
102-
params=cast(Dict[Any, Any], params),
122+
path=path,
123+
params=options,
124+
data=blob,
125+
headers=headers,
103126
verb="post",
104-
).perform()
127+
).perform_with_content()
105128
return resp
106129

107-
def translate(
108-
self, params: Union[TranslateParams, TranslateImageParams]
109-
) -> Union[TranslateResponse, TranslateListResponse, TranslateImageResponse]:
110-
if "url" in params or "file_store_key" in params:
111-
return self.translate_image(params)
112-
return self.translate_text(params)
113-
114130

115131
class AsyncTranslate(ClientConfig):
116132
config: RequestConfig
@@ -128,7 +144,7 @@ def __init__(
128144
disable_request_logging=disable_request_logging,
129145
)
130146

131-
async def translate_text(
147+
async def text(
132148
self, params: TranslateParams
133149
) -> Union[TranslateResponse, TranslateListResponse]:
134150
resp = await AsyncRequest(
@@ -138,21 +154,37 @@ async def translate_text(
138154
verb="post",
139155
).perform()
140156
return resp
141-
142-
async def translate_image(
143-
self, params: TranslateImageParams
157+
158+
@overload
159+
async def image(self, params: TranslateImageParams) -> TranslateImageResponse: ...
160+
@overload
161+
async def image(self, file: bytes, options: TranslateImageParams = None) -> TranslateImageParams: ...
162+
163+
async def image(
164+
self,
165+
blob: Union[TranslateImageParams, bytes],
166+
options: TranslateImageParams = None,
144167
) -> TranslateImageResponse:
168+
if isinstance(blob, dict):
169+
resp = await AsyncRequest(
170+
config=self.config,
171+
path="/ai/translate/image",
172+
params=cast(Dict[Any, Any], blob),
173+
verb="post",
174+
).perform_with_content()
175+
return resp
176+
177+
options = options or {}
178+
path = build_path(base_path="/ai/translate/image", params=options)
179+
content_type = options.get("content_type", "application/octet-stream")
180+
headers = {"Content-Type": content_type}
181+
145182
resp = await AsyncRequest(
146183
config=self.config,
147-
path="/ai/translate/image",
148-
params=cast(Dict[Any, Any], params),
184+
path=path,
185+
params=options,
186+
data=blob,
187+
headers=headers,
149188
verb="post",
150-
).perform()
151-
return resp
152-
153-
async def translate(
154-
self, params: Union[TranslateParams, TranslateImageParams]
155-
) -> Union[TranslateResponse, TranslateListResponse, TranslateImageResponse]:
156-
if "url" in params or "file_store_key" in params:
157-
return await self.translate_image(params)
158-
return await self.translate_text(params)
189+
).perform_with_content()
190+
return resp

0 commit comments

Comments
 (0)