3
3
from __future__ import annotations
4
4
5
5
import os
6
- from typing import Any , Union , Mapping
7
- from typing_extensions import Self , override
6
+ from typing import Any , Dict , Union , Mapping , cast
7
+ from typing_extensions import Self , Literal , override
8
8
9
9
import httpx
10
10
30
30
AsyncAPIClient ,
31
31
)
32
32
33
- __all__ = ["Timeout" , "Transport" , "ProxiesTypes" , "RequestOptions" , "Kernel" , "AsyncKernel" , "Client" , "AsyncClient" ]
33
+ __all__ = [
34
+ "ENVIRONMENTS" ,
35
+ "Timeout" ,
36
+ "Transport" ,
37
+ "ProxiesTypes" ,
38
+ "RequestOptions" ,
39
+ "Kernel" ,
40
+ "AsyncKernel" ,
41
+ "Client" ,
42
+ "AsyncClient" ,
43
+ ]
44
+
45
+ ENVIRONMENTS : Dict [str , str ] = {
46
+ "production" : "https://api.onkernel.com/" ,
47
+ "development" : "https://localhost:3001/" ,
48
+ }
34
49
35
50
36
51
class Kernel (SyncAPIClient ):
@@ -42,11 +57,14 @@ class Kernel(SyncAPIClient):
42
57
# client options
43
58
api_key : str
44
59
60
+ _environment : Literal ["production" , "development" ] | NotGiven
61
+
45
62
def __init__ (
46
63
self ,
47
64
* ,
48
65
api_key : str | None = None ,
49
- base_url : str | httpx .URL | None = None ,
66
+ environment : Literal ["production" , "development" ] | NotGiven = NOT_GIVEN ,
67
+ base_url : str | httpx .URL | None | NotGiven = NOT_GIVEN ,
50
68
timeout : Union [float , Timeout , None , NotGiven ] = NOT_GIVEN ,
51
69
max_retries : int = DEFAULT_MAX_RETRIES ,
52
70
default_headers : Mapping [str , str ] | None = None ,
@@ -77,10 +95,31 @@ def __init__(
77
95
)
78
96
self .api_key = api_key
79
97
80
- if base_url is None :
81
- base_url = os .environ .get ("KERNEL_BASE_URL" )
82
- if base_url is None :
83
- base_url = f"http://localhost:3001"
98
+ self ._environment = environment
99
+
100
+ base_url_env = os .environ .get ("KERNEL_BASE_URL" )
101
+ if is_given (base_url ) and base_url is not None :
102
+ # cast required because mypy doesn't understand the type narrowing
103
+ base_url = cast ("str | httpx.URL" , base_url ) # pyright: ignore[reportUnnecessaryCast]
104
+ elif is_given (environment ):
105
+ if base_url_env and base_url is not None :
106
+ raise ValueError (
107
+ "Ambiguous URL; The `KERNEL_BASE_URL` env var and the `environment` argument are given. If you want to use the environment, you must pass base_url=None" ,
108
+ )
109
+
110
+ try :
111
+ base_url = ENVIRONMENTS [environment ]
112
+ except KeyError as exc :
113
+ raise ValueError (f"Unknown environment: { environment } " ) from exc
114
+ elif base_url_env is not None :
115
+ base_url = base_url_env
116
+ else :
117
+ self ._environment = environment = "production"
118
+
119
+ try :
120
+ base_url = ENVIRONMENTS [environment ]
121
+ except KeyError as exc :
122
+ raise ValueError (f"Unknown environment: { environment } " ) from exc
84
123
85
124
super ().__init__ (
86
125
version = __version__ ,
@@ -122,6 +161,7 @@ def copy(
122
161
self ,
123
162
* ,
124
163
api_key : str | None = None ,
164
+ environment : Literal ["production" , "development" ] | None = None ,
125
165
base_url : str | httpx .URL | None = None ,
126
166
timeout : float | Timeout | None | NotGiven = NOT_GIVEN ,
127
167
http_client : httpx .Client | None = None ,
@@ -157,6 +197,7 @@ def copy(
157
197
return self .__class__ (
158
198
api_key = api_key or self .api_key ,
159
199
base_url = base_url or self .base_url ,
200
+ environment = environment or self ._environment ,
160
201
timeout = self .timeout if isinstance (timeout , NotGiven ) else timeout ,
161
202
http_client = http_client ,
162
203
max_retries = max_retries if is_given (max_retries ) else self .max_retries ,
@@ -212,11 +253,14 @@ class AsyncKernel(AsyncAPIClient):
212
253
# client options
213
254
api_key : str
214
255
256
+ _environment : Literal ["production" , "development" ] | NotGiven
257
+
215
258
def __init__ (
216
259
self ,
217
260
* ,
218
261
api_key : str | None = None ,
219
- base_url : str | httpx .URL | None = None ,
262
+ environment : Literal ["production" , "development" ] | NotGiven = NOT_GIVEN ,
263
+ base_url : str | httpx .URL | None | NotGiven = NOT_GIVEN ,
220
264
timeout : Union [float , Timeout , None , NotGiven ] = NOT_GIVEN ,
221
265
max_retries : int = DEFAULT_MAX_RETRIES ,
222
266
default_headers : Mapping [str , str ] | None = None ,
@@ -247,10 +291,31 @@ def __init__(
247
291
)
248
292
self .api_key = api_key
249
293
250
- if base_url is None :
251
- base_url = os .environ .get ("KERNEL_BASE_URL" )
252
- if base_url is None :
253
- base_url = f"http://localhost:3001"
294
+ self ._environment = environment
295
+
296
+ base_url_env = os .environ .get ("KERNEL_BASE_URL" )
297
+ if is_given (base_url ) and base_url is not None :
298
+ # cast required because mypy doesn't understand the type narrowing
299
+ base_url = cast ("str | httpx.URL" , base_url ) # pyright: ignore[reportUnnecessaryCast]
300
+ elif is_given (environment ):
301
+ if base_url_env and base_url is not None :
302
+ raise ValueError (
303
+ "Ambiguous URL; The `KERNEL_BASE_URL` env var and the `environment` argument are given. If you want to use the environment, you must pass base_url=None" ,
304
+ )
305
+
306
+ try :
307
+ base_url = ENVIRONMENTS [environment ]
308
+ except KeyError as exc :
309
+ raise ValueError (f"Unknown environment: { environment } " ) from exc
310
+ elif base_url_env is not None :
311
+ base_url = base_url_env
312
+ else :
313
+ self ._environment = environment = "production"
314
+
315
+ try :
316
+ base_url = ENVIRONMENTS [environment ]
317
+ except KeyError as exc :
318
+ raise ValueError (f"Unknown environment: { environment } " ) from exc
254
319
255
320
super ().__init__ (
256
321
version = __version__ ,
@@ -292,6 +357,7 @@ def copy(
292
357
self ,
293
358
* ,
294
359
api_key : str | None = None ,
360
+ environment : Literal ["production" , "development" ] | None = None ,
295
361
base_url : str | httpx .URL | None = None ,
296
362
timeout : float | Timeout | None | NotGiven = NOT_GIVEN ,
297
363
http_client : httpx .AsyncClient | None = None ,
@@ -327,6 +393,7 @@ def copy(
327
393
return self .__class__ (
328
394
api_key = api_key or self .api_key ,
329
395
base_url = base_url or self .base_url ,
396
+ environment = environment or self ._environment ,
330
397
timeout = self .timeout if isinstance (timeout , NotGiven ) else timeout ,
331
398
http_client = http_client ,
332
399
max_retries = max_retries if is_given (max_retries ) else self .max_retries ,
0 commit comments