Skip to content

Commit 09887c1

Browse files
Updating jigsawstack.embedding to accept both blob and params (type: literal, required).
1 parent 08b4781 commit 09887c1

File tree

1 file changed

+58
-7
lines changed

1 file changed

+58
-7
lines changed

jigsawstack/embedding.py

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
from typing import Any, Dict, List, Union, cast, Literal
1+
from typing import Any, Dict, List, Union, cast, Literal, overload
22
from typing_extensions import NotRequired, TypedDict
33
from .request import Request, RequestConfig
44
from .async_request import AsyncRequest
55
from typing import List, Union
66
from ._config import ClientConfig
7+
from .helpers import build_path
78

89

910
class EmbeddingParams(TypedDict):
@@ -38,12 +39,37 @@ def __init__(
3839
disable_request_logging=disable_request_logging,
3940
)
4041

41-
def execute(self, params: EmbeddingParams) -> EmbeddingResponse:
42-
path = "/embedding"
42+
@overload
43+
def execute(self, params: EmbeddingParams) -> EmbeddingResponse: ...
44+
@overload
45+
def execute(self, file: bytes, options: EmbeddingParams = None) -> EmbeddingResponse: ...
46+
47+
def execute(
48+
self,
49+
blob: Union[EmbeddingParams, bytes],
50+
options: EmbeddingParams = None,
51+
) -> EmbeddingResponse:
52+
path="/embedding"
53+
if isinstance(blob, dict):
54+
resp = Request(
55+
config=self.config,
56+
path=path,
57+
params=cast(Dict[Any, Any], blob),
58+
verb="post",
59+
).perform_with_content()
60+
return resp
61+
62+
options = options or {}
63+
path = build_path(base_path=path, params=options)
64+
content_type = options.get("content_type", "application/octet-stream")
65+
_headers = {"Content-Type": content_type}
66+
4367
resp = Request(
4468
config=self.config,
4569
path=path,
46-
params=cast(Dict[Any, Any], params),
70+
params=options,
71+
data=blob,
72+
headers=_headers,
4773
verb="post",
4874
).perform_with_content()
4975
return resp
@@ -66,12 +92,37 @@ def __init__(
6692
disable_request_logging=disable_request_logging,
6793
)
6894

69-
async def execute(self, params: EmbeddingParams) -> EmbeddingResponse:
70-
path = "/embedding"
95+
@overload
96+
async def execute(self, params: EmbeddingParams) -> EmbeddingResponse: ...
97+
@overload
98+
async def execute(self, file: bytes, options: EmbeddingParams = None) -> EmbeddingResponse: ...
99+
100+
async def execute(
101+
self,
102+
blob: Union[EmbeddingParams, bytes],
103+
options: EmbeddingParams = None,
104+
) -> EmbeddingResponse:
105+
path="/embedding"
106+
if isinstance(blob, dict):
107+
resp = await AsyncRequest(
108+
config=self.config,
109+
path=path,
110+
params=cast(Dict[Any, Any], blob),
111+
verb="post",
112+
).perform_with_content()
113+
return resp
114+
115+
options = options or {}
116+
path = build_path(base_path=path, params=options)
117+
content_type = options.get("content_type", "application/octet-stream")
118+
_headers = {"Content-Type": content_type}
119+
71120
resp = await AsyncRequest(
72121
config=self.config,
73122
path=path,
74-
params=cast(Dict[Any, Any], params),
123+
params=options,
124+
data=blob,
125+
headers=_headers,
75126
verb="post",
76127
).perform_with_content()
77128
return resp

0 commit comments

Comments
 (0)