1
- from typing import Any , Dict , List , Union , cast , Literal
1
+ from typing import Any , Dict , List , Union , cast , Literal , overload
2
2
from typing_extensions import NotRequired , TypedDict
3
3
from .request import Request , RequestConfig
4
4
from .async_request import AsyncRequest
5
5
from typing import List , Union
6
6
from ._config import ClientConfig
7
+ from .helpers import build_path
7
8
8
9
9
10
class EmbeddingParams (TypedDict ):
@@ -38,12 +39,37 @@ def __init__(
38
39
disable_request_logging = disable_request_logging ,
39
40
)
40
41
41
- def execute (self , params : EmbeddingParams ) -> EmbeddingResponse :
42
- path = "/embedding"
42
+ @overload
43
+ def execute (self , params : EmbeddingParams ) -> EmbeddingResponse : ...
44
+ @overload
45
+ def execute (self , file : bytes , options : EmbeddingParams = None ) -> EmbeddingResponse : ...
46
+
47
+ def execute (
48
+ self ,
49
+ blob : Union [EmbeddingParams , bytes ],
50
+ options : EmbeddingParams = None ,
51
+ ) -> EmbeddingResponse :
52
+ path = "/embedding"
53
+ if isinstance (blob , dict ):
54
+ resp = Request (
55
+ config = self .config ,
56
+ path = path ,
57
+ params = cast (Dict [Any , Any ], blob ),
58
+ verb = "post" ,
59
+ ).perform_with_content ()
60
+ return resp
61
+
62
+ options = options or {}
63
+ path = build_path (base_path = path , params = options )
64
+ content_type = options .get ("content_type" , "application/octet-stream" )
65
+ _headers = {"Content-Type" : content_type }
66
+
43
67
resp = Request (
44
68
config = self .config ,
45
69
path = path ,
46
- params = cast (Dict [Any , Any ], params ),
70
+ params = options ,
71
+ data = blob ,
72
+ headers = _headers ,
47
73
verb = "post" ,
48
74
).perform_with_content ()
49
75
return resp
@@ -66,12 +92,37 @@ def __init__(
66
92
disable_request_logging = disable_request_logging ,
67
93
)
68
94
69
- async def execute (self , params : EmbeddingParams ) -> EmbeddingResponse :
70
- path = "/embedding"
95
+ @overload
96
+ async def execute (self , params : EmbeddingParams ) -> EmbeddingResponse : ...
97
+ @overload
98
+ async def execute (self , file : bytes , options : EmbeddingParams = None ) -> EmbeddingResponse : ...
99
+
100
+ async def execute (
101
+ self ,
102
+ blob : Union [EmbeddingParams , bytes ],
103
+ options : EmbeddingParams = None ,
104
+ ) -> EmbeddingResponse :
105
+ path = "/embedding"
106
+ if isinstance (blob , dict ):
107
+ resp = await AsyncRequest (
108
+ config = self .config ,
109
+ path = path ,
110
+ params = cast (Dict [Any , Any ], blob ),
111
+ verb = "post" ,
112
+ ).perform_with_content ()
113
+ return resp
114
+
115
+ options = options or {}
116
+ path = build_path (base_path = path , params = options )
117
+ content_type = options .get ("content_type" , "application/octet-stream" )
118
+ _headers = {"Content-Type" : content_type }
119
+
71
120
resp = await AsyncRequest (
72
121
config = self .config ,
73
122
path = path ,
74
- params = cast (Dict [Any , Any ], params ),
123
+ params = options ,
124
+ data = blob ,
125
+ headers = _headers ,
75
126
verb = "post" ,
76
127
).perform_with_content ()
77
128
return resp
0 commit comments