diff --git a/inventree/api.py b/inventree/api.py index d749c92..7db2782 100644 --- a/inventree/api.py +++ b/inventree/api.py @@ -14,6 +14,8 @@ from requests.auth import HTTPBasicAuth from requests.exceptions import Timeout +from . import oAuthClient as oauth + logger = logging.getLogger('inventree') @@ -45,6 +47,9 @@ def __init__(self, host=None, **kwargs): token - Authentication token (if provided, username/password are ignored) token-name - Name of the token to use (default = 'inventree-python-client') use_token_auth - Use token authentication? (default = True) + use_oidc_auth - Use OIDC authentication? (default = False) + oidc_client_id - OIDC client ID (defaults to InvenTree public client) + oidc_scopes - OIDC scopes (default = ['openid', 'g:read']) verbose - Print extra debug messages (default = False) strict - Enforce strict HTTPS certificate checking (default = True) timeout - Set timeout to use (in seconds). Default: 10 @@ -56,6 +61,9 @@ def __init__(self, host=None, **kwargs): INVENTREE_API_PASSWORD - Password INVENTREE_API_TOKEN - User access token INVENTREE_API_TIMEOUT - Timeout value, in seconds + INVENTREE_API_OIDC - Use OIDC + INVENTREE_API_OIDC_CLIENT_ID - OIDC client ID + INVENTREE_API_OIDC_SCOPES - OIDC scopes """ self.setHostName(host or os.environ.get('INVENTREE_API_HOST', None)) @@ -68,8 +76,13 @@ def __init__(self, host=None, **kwargs): self.timeout = kwargs.get('timeout', os.environ.get('INVENTREE_API_TIMEOUT', 10)) self.proxies = kwargs.get('proxies', dict()) self.strict = bool(kwargs.get('strict', True)) + self.oidc_client_id = kwargs.get('oidc_client_id', os.environ.get('INVENTREE_API_OIDC_CLIENT_ID', 'zDFnsiRheJIOKNx6aCQ0quBxECg1QBHtVFDPloJ6')) + self.oidc_scopes = kwargs.get('oidc_scopes', os.environ.get('INVENTREE_API_OIDC_SCOPES', ['openid', 'g:read'])) self.use_token_auth = kwargs.get('use_token_auth', True) + self.use_oidc_auth = kwargs.get('use_oidc_auth', os.environ.get('INVENTREE_API_OIDC', False)) + if self.use_oidc_auth and self.use_token_auth: + self.use_token_auth = False self.verbose = kwargs.get('verbose', False) self.auth = None @@ -126,15 +139,18 @@ def connect(self): except Exception: raise ConnectionRefusedError("Could not connect to InvenTree server") + if self.use_oidc_auth: + self.requestOidcToken() + return + # Basic authentication self.auth = HTTPBasicAuth(self.username, self.password) if not self.testAuth(): raise ConnectionError("Authentication at InvenTree server failed") - if self.use_token_auth: - if not self.token: - self.requestToken() + if self.use_token_auth and not self.token: + self.requestToken() def constructApiUrl(self, endpoint_url): """Construct an API endpoint URL based on the provided API URL. @@ -273,6 +289,13 @@ def requestToken(self): return self.token + def requestOidcToken(self): + """Return authentication token from the server using OIDC.""" + client = oauth.OAuthClient(self.base_url, self.oidc_client_id, self.oidc_scopes) + self.token = client._access_token + + return self.token + def request(self, api_url, **kwargs): """ Perform a URL request to the Inventree API """ @@ -319,6 +342,9 @@ def request(self, api_url, **kwargs): if self.use_token_auth and self.token: headers['AUTHORIZATION'] = f'Token {self.token}' auth = None + elif self.use_oidc_auth and self.token: + headers['AUTHORIZATION'] = f'Bearer {self.token}' + auth = None else: auth = self.auth @@ -579,8 +605,9 @@ def downloadFile(self, url, destination, overwrite=False, params=None, proxies=d raise FileExistsError(f"Destination file '{destination}' already exists") if self.token: + headername = 'Token' if self.use_token_auth else 'Bearer' headers = { - 'AUTHORIZATION': f"Token {self.token}" + 'AUTHORIZATION': f"{headername} {self.token}" } auth = None else: diff --git a/inventree/oAuthClient.py b/inventree/oAuthClient.py new file mode 100644 index 0000000..dcd39c6 --- /dev/null +++ b/inventree/oAuthClient.py @@ -0,0 +1,105 @@ +import os +import urllib.parse as urlparse +import webbrowser +from http.server import BaseHTTPRequestHandler, HTTPServer + +from requests_oauthlib import OAuth2Session + +# Environment setup +os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1" +USABLE_PORT_RANGE = (29170, 292180) + + +class OAuthClient: + def __init__(self, server_url: str = "http://localhost:8000", client_id: str = '', scopes: list = None) -> None: + self.server_url = server_url + self.client_id = client_id + self.scopes = scopes if scopes is not None else [] + + self._handler_wrapper = RequestHandlerWrapper(self) + self._setup_callback() + self._poll_user() + + def get_url(self, path: str) -> str: + """Get the authorization URL.""" + return urlparse.urljoin(self.server_url, path) + + def _setup_callback(self): + for port in range(*USABLE_PORT_RANGE): + try: + self.server = HTTPServer(("127.0.0.1", port), self._handler_wrapper.request_handler) + self._port = port + break + except OSError: + continue + else: + raise Exception("No port found.") + + def _poll_user(self): + self._session = OAuth2Session( + self.client_id, scope=self.scopes, redirect_uri=f"http://localhost:{self._port}", pkce="S256" + ) + auth_url, state = self._session.authorization_url(self.get_url('/o/authorize/'), access_type="offline") + self._state = state + webbrowser.open_new_tab(auth_url) + + while not self._handler_wrapper.done: + self.server.handle_request() + if self._handler_wrapper.error: + raise Exception(self._handler_wrapper.error) + + def callback(self, callback_url: str): + self._session.fetch_token(self.get_url("/o/token/"), authorization_response=callback_url, include_client_id=True) + self._access_token = self._session.access_token + + +class RequestHandlerWrapper: + """Provides callback for OIDC endpoint.""" + def __init__(self, oauth_client) -> None: + self.done = False + self.error = None + self.client: OAuthClient = oauth_client + + @property + def request_handler(self): + wrapper = self + + class RequestHandler(BaseHTTPRequestHandler): + def do_GET(self): + parsed_url = urlparse.urlparse(self.path) + if parsed_url.path == "/": + error = urlparse.parse_qs(parsed_url.query).get("error", [None])[0] + if error: + wrapper.error = error + self.send(200) + else: + try: + wrapper.client.callback(self.path) + except OAuthError as e: + wrapper.error = e.message + self.send(400) + else: + self.send(200, 'Success! You can close this window.') + wrapper.done = True + else: + self.send(404) + + def send(self, status_code, content=None): + self.send_response(status_code) + if content: + self.wfile.write(content.encode("utf-8")) + else: + self.wfile.write(b"") + self.send_header("Content-type", "text/html") + self.end_headers() + + def log_message(self, *args): + pass # Suppress logging + + return RequestHandler + + +class OAuthError(Exception): + """Exception raised during the OAuth process.""" + def __init__(self, message: str) -> None: + self.message = message diff --git a/pyproject.toml b/pyproject.toml index fc346c4..2a5a454 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ "pip-system-certs>=4.0", "requests>=2.27.0", "urllib3>=2.3.0", + "requests-oauthlib", ] [project.urls] diff --git a/requirements.txt b/requirements.txt index 4f1a39d..d30494c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ invoke>=1.4.0 coverage>=6.4.1 # Run tests, measure coverage coveralls>=3.3.1 Pillow>=9.1.1 +requests-oauthlib # Modern auth experience \ No newline at end of file