diff --git a/localhost.crt b/localhost.crt new file mode 100644 index 0000000..2705484 --- /dev/null +++ b/localhost.crt @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDDzCCAfegAwIBAgIUeheYrRpad428Bo7n3O4iSeDvl/wwDQYJKoZIhvcNAQEL +BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTI1MDIxNzE3MzExOFoXDTI1MDMx +OTE3MzExOFowFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF +AAOCAQ8AMIIBCgKCAQEA14LBVxD5j3Rs0lqANYUoNHYPBGII1Z9Lh1SjEyapI857 +JUYZm9VwVLqZJmcv8ppQeyE6nQOI+EvmMC7zy2w0us4hmayH93XEjyBgFv9KNifG +csLH9FDwJmBK1pyIoSgWE82GjuDJOA/ociuFoi5TjNgL1F4DpS7zOMA+OJcOW9B3 +cWMT4F7oGLfnhrd+J1k/vsDGU047hivspFXKNBA1HpVyHC/nAIKUNZZkCaR0NjCI +zSGxQ19f2uG59J1jizlruk3y7UlvFskSGUocvRiNzPgsMpWplxwszqv0eEYV/Ber +HjaLnanEcF7zGEuGc2oMWziTmLPBQF4nKOlJGJsACwIDAQABo1kwVzAUBgNVHREE +DTALgglsb2NhbGhvc3QwCwYDVR0PBAQDAgeAMBMGA1UdJQQMMAoGCCsGAQUFBwMB +MB0GA1UdDgQWBBRI2ZH2Clz43NSxOxxXxm7tBaUGdzANBgkqhkiG9w0BAQsFAAOC +AQEAryDWVc6VXTn4rV/BXdxLMj//0aqpzmjY8OE+LElqEijvUPlM5RJCteK/bi87 +C8t02yDmaWnRuoblguusJayVDBSh3bxG6QjT+ijjTgOVDs30aOvTp3mxb34QYhox ++Mwg9IDrmqoSsXd9uNToBV4639IRmGSE/4yz/OTq+F1tlGpi8REhxIbmyN3ODExK +Uyp1wALTAR4koErqzruQ7vh7oWMNU3Mb6w6SizCe9K2oqMCQ8zCHSHKnP/eQn4Bo +6dzvX7HrTsJfbQKosMrDJQhi6QXZBE7DRzbY4vYpQ8E3HqvoAv9jBd1KrkXPMpN0 +ORR43t6HviCz9puPh0kMkdGRIg== +-----END CERTIFICATE----- diff --git a/localhost.key b/localhost.key new file mode 100644 index 0000000..3f70b28 --- /dev/null +++ b/localhost.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDXgsFXEPmPdGzS +WoA1hSg0dg8EYgjVn0uHVKMTJqkjznslRhmb1XBUupkmZy/ymlB7ITqdA4j4S+Yw +LvPLbDS6ziGZrIf3dcSPIGAW/0o2J8Zywsf0UPAmYErWnIihKBYTzYaO4Mk4D+hy +K4WiLlOM2AvUXgOlLvM4wD44lw5b0HdxYxPgXugYt+eGt34nWT++wMZTTjuGK+yk +Vco0EDUelXIcL+cAgpQ1lmQJpHQ2MIjNIbFDX1/a4bn0nWOLOWu6TfLtSW8WyRIZ +Shy9GI3M+CwylamXHCzOq/R4RhX8F6seNoudqcRwXvMYS4ZzagxbOJOYs8FAXico +6UkYmwALAgMBAAECggEACf+SebgbY+jqXY3+Ub1WQqzRgIoNz99ekS4/jJVoFnWv +Z+jLKlwqJHwtu8bgxhgbsMK3Ze5yjdZznPuoquDfx2Tl0SvceQIZNuyxGJAKgN2y +isN3pGGW6qjf//nuKs/hylRoMDvEihnO1nEnd4E/thKV5engsGqvtQvSNyzm6SJ6 +d/u5uFK7VQrWAkJ1+xCSTPPWi77CiMBh/qMY2Y6Ph0wYexQdNmS8np1cN7c5+N1b +evLSQAIiNUhXmVYY31z6oWHq+CYM/CRyWgDfQHFlUoBisIizzMptvDcvFi+7Rp22 +qp1Lm6KNOPG7/nq/zDNSBfd1FqlaR/TcQDMtVUKaCQKBgQD7S+/NovCBnkhiEciU +GHGIvpfoyXO/UfiGKcf4WZCKfqlz9RNvzmuekRK1GAJzRnsO6HBeNwPasTHFIBJd +WZIJ4nJS681R/TTLfIPveKuo99F1iB2mZHDX9dkQS3ZlK7OuudVuE0xTcUoNAD0P +9zWcSoBrLrEToHYldtbq+FW8swKBgQDbi1qhNduMV/wq5SuEAntSraKdKcJho47R +MkNu3grg5sNNaMbQJhiO5KHBdfPefgQAqp0ZuYSMPg8GPnIrKN4nFegkBUWgjoi9 +zAgfmdSaExo8gNe3S5J9wSEVpskVnDtLjS0FYcTOCLnNneiim8TGLKHXUpvWOTdt +wcfPCgGLSQKBgBFlk2dgBVhj1cz8QC+IdauqziductXm3daj49UcljYQSLjfWYYe ++zJSBsKEs/64/WHt04GiO2ETbUehTcQqpEKM668z5dXsOpBvwU59wxyCc3y4fJz9 +TRaWTX2kS8D7QogxE0Z4jYslR6QYxSFq0spMGhHRfK7IKAW18XD42i6jAoGAMAZL +zPf7DrgwcTGwUzA3yd4xtC9uVe1xUFGubpIjzw6rqkNBOkcbGCbrO2aR8hmexoaL +1xS96e+pWbRPRSGrduFT5o1Ard6ACwSWwlLkLs/+7T1B8taVNO0KT7IsSo3iaqR3 +NLYuVuORwWjJesiYQsGApZlsfXAGr/uzuZZ2wAECgYEAgB9QWsVzWHhbJ1YNK9qE +igWV4src1lRcwq3pN5iYLT6NXMvLZjPHOtEz5PRO4M85SwDlOG+82BWmDMNgFwoq +W+BC+xjRbqywT/7Acj0zS1GUm2gPqPt606b/H8L4WjENJO23C1iEF/brOb6GJKxy +WBfnsVRqocHJ/hHuLIdBvMQ= +-----END PRIVATE KEY----- diff --git a/packages/auth0-ai/README.md b/packages/auth0-ai/README.md new file mode 100644 index 0000000..dfe3b5f --- /dev/null +++ b/packages/auth0-ai/README.md @@ -0,0 +1,82 @@ +# Auth0 AI + +This package provides base methods to use Auth0 with your AI use cases. + +## Installation + +```bash +pip install git+https://github.com/mustafadeel/auth0-ai-python.git@mod_struct#subdirectory=packages/auth0-ai +``` + +## Running Tests + +1. **Install Dependencies** + + Use [Poetry](https://python-poetry.org/) to install the required dependencies: + + ```sh + $ poetry install + ``` + +2. **Run the tests** + + ```sh + $ poetry run pytest tests + ``` + +## Usage + +Create a .env file with the following deatils: + +``` +AUTH0_DOMAIN='<>' +AUTH0_CLIENT_ID='<>' +AUTH0_CLIENT_SECRET='<>' +AUTH0_REDIRECT_URI='<>' +AUTH0_SECRET_KEY='ALongRandomlyGeneratedString' +``` + +Create a python script for an interactive login, link and tool token example: + +```python +from dotenv import find_dotenv, load_dotenv + +import asyncio + +from auth0_ai import AIAuth, User + +ENV_FILE = find_dotenv() +if ENV_FILE: + load_dotenv(ENV_FILE) + +auth_client = AIAuth() + +async def login(): + return await auth_client.interactive_login(connection="Username-Password-Authentication", scope="openid email offline_access") + +async def link(user_id, connection): + linked = await auth_client.link(primary_user_id=user_id, connection=connection, scope="openid email") + return linked + +user1 = asyncio.run(login()) + +print("-" * 20) +print("USER DETAILS:", auth_client.get_session(user1)) + +link_status = asyncio.run(user1.link(connection="github")) + +github_token = user1.get_3rd_party_token("github") +``` + +--- + +

+ + + + Auth0 Logo + +

+

Auth0 is an easy to implement, adaptable authentication and authorization platform. To learn more checkout Why Auth0?

+

+This project is licensed under the Apache 2.0 license. See the LICENSE file for more info.

diff --git a/packages/auth0-ai/auth0_ai/__init__.py b/packages/auth0-ai/auth0_ai/__init__.py new file mode 100644 index 0000000..874d9eb --- /dev/null +++ b/packages/auth0-ai/auth0_ai/__init__.py @@ -0,0 +1,3 @@ +from .ai_auth import AIAuth, User + +__all__ = ["AIAuth", "User"] diff --git a/packages/auth0-ai/auth0_ai/ai_auth.py b/packages/auth0-ai/auth0_ai/ai_auth.py new file mode 100644 index 0000000..e74f1f3 --- /dev/null +++ b/packages/auth0-ai/auth0_ai/ai_auth.py @@ -0,0 +1,738 @@ +from __future__ import annotations + +from typing import Any + +from auth0.authentication.base import AuthenticationBase +from auth0.authentication import GetToken +from auth0.authentication import RevokeToken +from auth0.authentication.async_token_verifier import AsyncAsymmetricSignatureVerifier +from auth0.authentication.pushed_authorization_requests import PushedAuthorizationRequests + +from .session_storage import SessionStorage + +import webbrowser +import urllib.parse + +import os + +import jwt # PyJWT for signing cookies +from fastapi import FastAPI, Request, HTTPException, Response +from fastapi.responses import JSONResponse, RedirectResponse + +from typing import Any, Dict + +import uvicorn +import secrets +import threading + +import time +import json + + +class AIAuth(AuthenticationBase): + + def __init__( + self, + domain: str | None = None, + client_id: str | None = None, + client_secret: str | None = None, + redirect_uri: str | None = None, + *args, **kwargs): + + self.domain = domain or os.environ.get("AUTH0_DOMAIN") + self.client_id = client_id or os.environ.get("AUTH0_CLIENT_ID") + self.client_secret = client_secret or os.environ.get( + "AUTH0_CLIENT_SECRET") + self.redirect_uri = redirect_uri or os.environ.get( + "AUTH0_REDIRECT_URI") + + """Initialize AIAuth and set up the middleware app with the callback route.""" + super().__init__( + domain=self.domain, + client_id=self.client_id, + client_secret=self.client_secret, + *args, **kwargs) # Initialize parent class + + @property + def domain(self): + return self._domain + + @domain.setter + def domain(self, val): + if not val: + raise ValueError( + "domain cannot be empty. you can also set AUTH0_DOMAIN value in .env file") + self._domain = val + + @property + def client_id(self): + return self._client_id + + @client_id.setter + def client_id(self, val): + if not val: + raise ValueError( + "client_id cannot be empty. you can also AUTH0_CLIENT_ID value in .env file") + self._client_id = val + + @property + def client_secret(self): + return self._client_secret + + @client_secret.setter + def client_secret(self, val): + if not val: + raise ValueError( + "client_secret cannot be empty. you can also AUTH0_CLIENT_SECRET value in .env file") + self._client_secret = val + + @property + def redirect_uri(self): + return self._redirect_uri + + @redirect_uri.setter + def redirect_uri(self, val): + if not val: + raise ValueError( + "redirect_uri cannot be empty. you can also AUTH0_REDIRECT_URI value in .env file") + self._redirect_uri = val + + # Initialize token verifier + jwk_url = f"https://{self.domain}/.well-known/jwks.json" + self.token_verifier = AsyncAsymmetricSignatureVerifier( + jwks_url=jwk_url) + + # Initialize FastAPI app + self.app = FastAPI() + # Temporary store for state values + self.state_store: Dict[str, Dict[bool, str]] = {} + # or secrets.token_urlsafe(32) # Secure random secret key + self.secret_key = os.environ.get("AUTH0_SECRET_KEY") + + # Initialize the SessionStore + self.session_store = SessionStorage() + + # Register the callback route + @self.app.get("/auth/callback") + async def manage_callback(request: Request, response: Response): + """Parses and validates callback URL query parameters.""" + query_params = request.query_params + required_keys = {"code", "state"} + + if query_params.get("error"): + error_description = query_params.get( + "error_description", "Unknown error occurred.") + if query_params.get("state"): + del self.state_store[query_params.get("state")] + raise HTTPException(status_code=400, detail=error_description) + + if not required_keys.issubset(query_params.keys()): + raise HTTPException( + status_code=400, detail="Missing required query parameters.") + + received_state = query_params["state"] + + # Validate state to prevent CSRF attacks + if received_state not in self.state_store: + raise HTTPException( + status_code=400, detail="Invalid or missing state parameter.") + + # Extract code value from query string + received_code = query_params["code"] + + auth0_tokens = self._exchange_code_for_tokens(received_code) + + if auth0_tokens: + + cookie_session_data = await self._set_encrypted_session(auth0_tokens, state=received_state) + + user_id = self.state_store[received_state].get( + "user_id", "failed") + return_to = self.state_store[received_state].get("return_to", None) + self.state_store[received_state] = { + "user_id": user_id, "is_completed": True} + + if return_to: + response = RedirectResponse(url=return_to, status_code=302) + + else: + response.body = b'{"message": "login successful"}' + response.status_code = 200 + + response.set_cookie( + key="sessionData", + value=cookie_session_data, + path="/auth", + httponly=True, # Prevent JavaScript access + secure=True, # Send only over HTTPS + samesite="Lax", # Protect against CSRF + # set expiry based on access token expiry + max_age=auth0_tokens["expires_in"], + ) + + return response + else: + raise HTTPException( + status_code=400, detail="Failed to exchange code for tokens.") + + # Register the login route + @self.app.get("/auth/login") + async def manage_login(request: Request, response: Response, + return_to: str | None = None, scope: str | None = None, connection: str | None = None): + + # check cookie for existing session + auth_cookie = request.cookies.get("sessionData") + if auth_cookie: + decoded_data = jwt.decode( + auth_cookie, self.secret_key, algorithms=["HS256"]) + # Session cookie exists, do something with it + # ... + return {"session": decoded_data} + else: + # No session cookie, redirect to Auth0 + _scope = scope or "openid profile email" + _connection = connection or "Username-Password-Authentication" + + state = self._generate_state(return_to=return_to) + + auth_url = self.get_authorize_url( + state=state, connection=_connection, scope=_scope) + + return RedirectResponse(url=auth_url, status_code=302) + + @self.app.get("/auth/get_user") + async def get_user(request: Request): + """Reads the session cookie and extracts user info.""" + auth_cookie = request.cookies.get("sessionData") + + if not auth_cookie: + raise HTTPException( + status_code=401, detail="Missing session cookie.") + + try: + # Decode the JWT stored in the session cookie + decoded_data = jwt.decode( + auth_cookie, self.secret_key, algorithms=["HS256"]) + + # Extract the user ID (sub) from the decoded JWT + user_id = decoded_data.get('user').get('sub') + + if not user_id: + raise HTTPException( + status_code=400, detail="Invalid session cookie: Missing 'sub' claim.") + + return (JSONResponse(content=decoded_data)) + + except jwt.ExpiredSignatureError: + raise HTTPException( + status_code=401, detail="Session cookie has expired.") + except jwt.InvalidTokenError: + raise HTTPException( + status_code=401, detail="Invalid session cookie.") + + @self.app.get("/auth/logout") + async def manage_logout(request: Request, response: Response): + """Reads the session cookie and extracts user info.""" + auth_cookie = request.cookies.get("sessionData") + + if not auth_cookie: + raise HTTPException( + status_code=401, detail="Missing session cookie.") + + try: + # Decode the JWT stored in the session cookie + decoded_data = jwt.decode( + auth_cookie, self.secret_key, algorithms=["HS256"]) + + # Extract the user ID (sub) from the decoded JWT + user_id = decoded_data.get('user').get('sub') + + if not user_id: + raise HTTPException( + status_code=400, detail="Invalid session cookie: Missing 'sub' claim.") + + self.get(url=f"https://{self.domain}/v2/logout") + rt = token=decoded_data.get("tokens").get("refresh_token", None) + if rt: + rt_manager = RevokeToken(self.domain, self.client_id, self.client_secret) + rt_manager.revoke_refresh_token(token=decoded_data.get("tokens").get("refresh_token")) + + response.delete_cookie(key="sessionData",path="/auth") + self.session_store._delete_stored_session(user_id) + + # MODIFY RESPONSE to ensure it returns properly + response.body = b'{"message": "logout successful"}' + response.status_code = 200 + response.media_type = "application/json" + + return response + + except jwt.ExpiredSignatureError: + raise HTTPException( + status_code=401, detail="Session cookie has expired.") + except jwt.InvalidTokenError: + raise HTTPException( + status_code=401, detail="Invalid session cookie.") + + + # Start middleware server in a separate thread + self.host = urllib.parse.urlparse(self.redirect_uri).hostname + self.port = urllib.parse.urlparse(self.redirect_uri).port + self.protocol = urllib.parse.urlparse(self.redirect_uri).scheme + self._start_server() + + def _is_valid_file(self,file_path): + """Check if the file exists and is accessible.""" + return os.path.isfile(file_path) and os.access(file_path, os.R_OK) + + def _start_server(self): + """Runs FastAPI as the middleware inside a separate thread.""" + if (self.protocol == "https"): + + ssl_keyfile = os.getenv("AUTH0_SSL_KEYFILE") + ssl_certfile = os.getenv("AUTH0_SSL_CERTFILE") + + if not self._is_valid_file(ssl_keyfile) or not self._is_valid_file(ssl_certfile): + raise ValueError( + "AUTH0_SSL_KEYFILE and AUTH0_SSL_CERTFILE environment variables must be set with valid file paths for HTTPS.") + + server_thread = threading.Thread( + target=uvicorn.run, + args=(self.app,), + kwargs={ + "host": self.host, + "port": self.port, + "ssl_keyfile": ssl_keyfile, # Path to private key + "ssl_certfile": ssl_certfile, # Path to certificate + "log_level": "error"}, + daemon=True # Daemon mode so it exits when the main thread exits + ) + else: + server_thread = threading.Thread( + target=uvicorn.run, + args=(self.app,), + kwargs={ + "host": self.host, + "port": self.port, + "log_level": "info"}, + daemon=True # Daemon mode so it exits when the main thread exits + ) + try: + server_thread.start() + except Exception as e: + print(f"Error starting middleware server: {str(e)}") + + def _generate_state(self, return_to: str | None = None) -> str: + """Generate a secure random state and store it for validation.""" + state = secrets.token_urlsafe(16) # Generate a random state + # Store it temporarily and flag it as false as we havent received it back as yet + self.state_store[state] = {"is_competed": False, "return_to": return_to} + return state + + def _get_token_set(self, token_data: str, existing_refresh_token: str | None = None) -> dict: + """Extracts the access token, scope, refresh token, and expiry time from the token_data.""" + token_data = { + "access_token": token_data.get("access_token"), + "scope": token_data.get("scope"), + "expires_at": {"epoch": int(time.time())+token_data["expires_in"]}, + "refresh_token": token_data.get("refresh_token", existing_refresh_token), + } + + return token_data + + def _get_linked_details(self, token_data: dict, existing_linked_connections: list[str] | None = None) -> list[str]: + """Extracts unique link_with values from authorization_details if type is 'account_linking' + and appends them to existing_linked_connections, avoiding duplicates. + """ + authz_details = token_data.get("authorization_details", []) + # Use a set to enforce uniqueness + linked_connections = set(existing_linked_connections or []) + + for item in authz_details: + if item.get("type") == "account_linking": + link_with = item.get("linkParams", {}).get("link_with") + if link_with: # Ensure link_with is not None + # Add to set to avoid duplicates + linked_connections.add(link_with) + + # Convert back to a list before returning + return list(linked_connections) + + async def _set_encrypted_session(self, token_data, state: str | None = None) -> str: + id_token = token_data.get("id_token", "") + decoded_id_token = {} + if id_token: + try: + decoded_id_token = await self.token_verifier.verify_signature(id_token) + user_id = decoded_id_token.get("sub") # Primary Key + if not user_id: + raise HTTPException( + status_code=400, detail="ID token missing 'sub' claim.") + except Exception as e: + raise HTTPException( + status_code=400, detail=f"Invalid ID token: {str(e)}") + else: + # this can happen in the case of the linking flow if the openid scope is not requested + # and no ID token is issued + user_id = self.state_store[state].get("user_id") + + existing_encrypted_session = self.session_store._get_stored_session( + user_id) + existing_linked_connections = {} + existing_refresh_token = {} + + if existing_encrypted_session: + # found existing session, check if there is a refresh token to keep + existing_session = self._get_encrypted_session(user_id) + if existing_session: + existing_refresh_token = existing_session.get( + "tokens").get("refresh_token", None) + existing_linked_connections = existing_session.get( + "linked_connections", None) + + session_data = {} + session_data = {"user": decoded_id_token, + "id_token": {"id_token": id_token,"id_token_expiry": decoded_id_token.get("exp")}, + "tokens": self._get_token_set(token_data, existing_refresh_token), + "linked_connections": self._get_linked_details(token_data, existing_linked_connections) + } + + encrypted_session_data = jwt.encode( + session_data, self.secret_key, algorithm="HS256") + + # Stored in memory & auto-persisted + self.session_store._set_stored_session( + user_id=user_id, encrypted_session_data=encrypted_session_data) + + self.state_store[state]["user_id"] = user_id + + # print("Session created/updated for:",user_id) + return encrypted_session_data + + def _get_encrypted_session(self, user_id): + encrypted_session = self.session_store._get_stored_session(user_id) + + if not encrypted_session: + return {"not found"} + + try: + # Decode the JWT stored in the session cookie + decoded_data = jwt.decode( + encrypted_session, self.secret_key, algorithms=["HS256"]) + + # Extract the user ID (sub) from the decoded JWT + user_id = decoded_data.get('user').get('sub') + + token_expiry = decoded_data.get("tokens", {}).get( + "expires_at", {}).get("epoch") + + if token_expiry > int(time.time()): + return decoded_data + else: + refresh_token = decoded_data.get( + "tokens", {}).get("refresh_token") + if refresh_token: + self._update_encrypted_session(user_id, refresh_token) + else: + self.session_store._delete_stored_session(user_id) + return {"session expired"} + + except jwt.ExpiredSignatureError: + return {"Session cookie has expired."} + except jwt.InvalidTokenError: + return {"Invalid session."} + + def _update_encrypted_session(self, user_id, refresh_token): + token_manager = GetToken(self.domain, self.client_id, self.client_secret) + updated_tokens = token_manager.refresh_token( + refresh_token=refresh_token) + + if updated_tokens: + self._set_encrypted_session(updated_tokens) + + def get_authorize_url( + self, + state: str, + connection: str | None = None, + scope: str | None = None, + additional_scopes: str | None = None, + **kwargs, + ) -> str: + base_url = ( + f"https://{self.domain}/authorize?" + f"response_type=code&" + f"client_id={self.client_id}&" + f"redirect_uri={self.redirect_uri}&" + f"grant_type=authorization_code&" + f"state={state}&" + ) + + if connection is not None: + base_url += f"connection={connection}&" + + if scope is not None: + base_url += f"scope={scope}&" + + if additional_scopes is not None: + base_url += f"connection_scope={additional_scopes}&" + + # Add any additional custom arguments passed via kwargs + custom_args = "&".join( + [f"{key}={value}" for key, value in kwargs.items()]) + if custom_args: + base_url += custom_args + + return base_url + + def get_authorize_par_url( + self, + state: str, + request_uri: str, + ) -> str: + + base_url = ( + f"https://{self.domain}/authorize?" + f"client_id={self.client_id}&" + f"state={state}&" + f"request_uri={request_uri}" + ) + + return base_url + + def _exchange_code_for_tokens( + self, + code: str, + ) -> dict[str, Any]: + + get_token = GetToken(self.domain, self.client_id, self.client_secret) + + token_info = get_token.authorization_code( + code=code, + redirect_uri=self.redirect_uri, + grant_type="authorization_code" + ) + + return token_info + + def get_upstream_token( + self, + connection, + refresh_token: str, + additional_scopes: str | None = None, + ) -> dict[str, Any]: + + fcat = GetToken(self.domain, self.client_id, self.client_secret) + + x = fcat.federated_connection_access_token( + subject_token_type="urn:ietf:params:oauth:token-type:refresh_token", + subject_token=refresh_token, + requested_token_type="http://auth0.com/oauth/token-type/federated-connection-access-token", + connection=connection, + grant_type="urn:auth0:params:oauth:grant-type:token-exchange:federated-connection-access-token" + ) + + return (x) + + def get_session(self, user: User) -> dict[str, Any]: + if user.user_id in self.session_store._get_stored_sessions(): + session = self._get_encrypted_session(user.user_id) + return (session.get("user")) + else: + return {"user_id not found in session store"} + + async def interactive_login(self, connection: str | None = None, scope: str | None = None, **kwargs) -> User: + + if scope is None: + scope = "openid profile email" + + state = self._generate_state() + + class LoginState: + def __init__(self, state_store, state): + self.state_store = state_store + self.state = state + self.start_time = time.time() + + def is_completed(self) -> bool: + return self.state_store[self.state].get("is_completed") + + def get_user(self) -> str: + return self.state_store.get(self.state, "login failed!") + + def terminate(self): + del self.state_store[self.state] + + login_state = LoginState(self.state_store, state) + + auth_url = self.get_authorize_url( + state=state, connection=connection, scope=scope, **kwargs) + + # this initiates the login flow, and if successful the session is created + try: + webbrowser.open(auth_url) + except webbrowser.Error: + print("Please navigate here: {auth_url}") + + # check of compleition of login flow + while not login_state.is_completed(): + # not sure if this needed, but can save some unecessary polling time + if (time.time() < login_state.start_time + 60): + time.sleep(0.25) + else: + # login has timed out, we can clean up state + login_state.terminate() + return ("login timeout") + + user_id = login_state.get_user() + + # no longer need to retain the state anymore, we can clean it up + login_state.terminate() + + if user_id == "login failed!": + return "login failed" + else: + return User(self, user_id=user_id.get("user_id")) + + async def link(self, primary_user_id: str, connection: str, id_token: str, scope: str | None = None, **kwargs) -> str: + + state = self._generate_state() + + class LinkState: + def __init__(self, state_store, state): + self.state_store = state_store + self.state = state + self.start_time = time.time() + + def is_completed(self) -> bool: + return self.state_store[self.state].get("is_completed", False) + + def get_user(self) -> str: + return self.state_store.get(self.state, "login failed!").get("user_id") + + def set_user(self, user_id: str) -> None: + is_comeplted = self.state_store[self.state].get( + "is_competed", False) + self.state_store[self.state] = { + "is_completed": is_comeplted, "user_id": user_id} + + def terminate(self): + del self.state_store[self.state] + + par_client = PushedAuthorizationRequests( + self.domain, self.client_id, self.client_secret) + + link_state = LinkState(self.state_store, state) + link_state.set_user(primary_user_id) + + try: + y = par_client.pushed_authorization_request( + response_type="code", + nonce="mynonce", + redirect_uri=self.redirect_uri, + audience="my-account", + state=state, + authorization_details=json.dumps([ + {"type": "link_account", "requested_connection": connection} + ]), + scope="openid profile", + id_token_hint=id_token, + ) + + request_uri = y.get('request_uri') + + if request_uri: + auth_url = self.get_authorize_par_url( + state=state, request_uri=request_uri) + + # this initiates the login flow for the connection to link + try: + webbrowser.open(auth_url) + except webbrowser.Error: + print("Please navigate here:", auth_url) + + # check of compleition of login flow + while not link_state.is_completed(): + # not sure if this needed, but can save some unecessary polling time + if (time.time() < link_state.start_time + 60): + time.sleep(0.05) + else: + link_state.terminate() + return ("linking timeout") + + user_id = link_state.get_user() + + # no longer need state, we can clean it up + link_state.terminate() + + if user_id == "login failed!": + successul_login = False + else: + successul_login = True + + link_response = { + "is_successful": successul_login, + "user_id": user_id, + } + else: + link_response = { + "is_successful": False, + "user_id": user_id, + } + + except Exception as error: + print(error) + link_response = { + "is_successful": False, + "user_id": primary_user_id, + } + + return link_response + + +class User(AIAuth): + + def __init__(self, parent, user_id: str): + self.parent = parent + self.client = parent.client + self.user_id = user_id + + async def link(self, connection: str | None = None, scope: str | None = None, **kwargs) -> str: + return await self.parent.link(connection=connection, primary_user_id=self.user_id, id_token=self.get_id_token(), scope=scope, **kwargs) + + def get_id_token(self) -> str: + if self.user_id in self.parent.session_store._get_stored_sessions(): + return (self.parent._get_encrypted_session(self.user_id).get("tokens").get("id_token")) + else: + return {"user_id not found in session store"} + + def get_access_token(self) -> str: + if self.user_id in self.parent.session_store._get_stored_sessions(): + return (self.parent._get_encrypted_session(self.user_id).get("tokens").get("access_token")) + else: + return {"user_id not found in session store"} + + def get_refresh_token(self) -> str: + if self.user_id in self.parent.session_store._get_stored_sessions(): + return (self.parent._get_encrypted_session(self.user_id).get("tokens").get("refresh_token")) + else: + return {"user_id not found in session store"} + + def get_3rd_party_token(self, connection: str) -> dict[str, Any]: + return self.parent.get_upstream_token(connection, self.get_refresh_token()) + + def tokeninfo(self) -> dict[str, Any]: + id_token = self.get_id_token() + self.parent.post() + data: dict[str, Any] = self.parent.get( + url=f"https://{self.parent.domain}/tokeninfo", + headers={"Authorization": f"Bearer {access_token}"}, + ) + return data + + def userinfo(self, access_token: str | None = None) -> dict[str, Any]: + access_token = access_token or self.get_access_token() + data: dict[str, Any] = self.parent.get( + url=f"https://{self.parent.domain}/userinfo", + headers={"Authorization": f"Bearer {access_token}"}, + ) + return data diff --git a/packages/auth0-ai/auth0_ai/auth/__init__.py b/packages/auth0-ai/auth0_ai/auth/__init__.py new file mode 100644 index 0000000..1e15861 --- /dev/null +++ b/packages/auth0-ai/auth0_ai/auth/__init__.py @@ -0,0 +1,5 @@ +from .auth_client import AIAuth, User +from .base import BaseAuth +from .user import User + +__all__ = ["AIAuth", "BaseAuth", "User"] diff --git a/packages/auth0-ai/auth0_ai/auth/auth_client.py b/packages/auth0-ai/auth0_ai/auth/auth_client.py new file mode 100644 index 0000000..d13d4d5 --- /dev/null +++ b/packages/auth0-ai/auth0_ai/auth/auth_client.py @@ -0,0 +1,222 @@ +from __future__ import annotations +from typing import Any +import webbrowser +import secrets +import json +from typing import Any, Dict + +from auth0.authentication.async_token_verifier import AsyncAsymmetricSignatureVerifier +from auth0.authentication.pushed_authorization_requests import PushedAuthorizationRequests + +from .base import BaseAuth +from .user import User +from auth0_ai.server.auth_server import AuthServer +from auth0_ai.token_module.manager import TokenManager +from auth0_ai.session_module.manager import SessionManager +from auth0_ai.state.login_state import LoginState +from auth0_ai.state.link_state import LinkState +from auth0_ai.utils.url_builder import URLBuilder + +class AIAuth(BaseAuth): + """Main authentication class that orchestrates the auth flow""" + def __init__( + self, + domain: str | None = None, + client_id: str | None = None, + client_secret: str | None = None, + redirect_uri: str | None = None, + secret_key: str | None = None, + *args, **kwargs): + """Initialize AIAuth with all necessary components""" + super().__init__( + domain=domain, + client_id=client_id, + client_secret=client_secret, + redirect_uri=redirect_uri, + secret_key=secret_key, + *args, **kwargs + ) + # Initialize token verifier + jwk_url = f"https://{self.domain}/.well-known/jwks.json" + self.token_verifier = AsyncAsymmetricSignatureVerifier(jwks_url=jwk_url) + # Initialize components + self.state_store: Dict[str, Dict[str, Any]] = {} + self.session_manager = SessionManager(self) + self.token_manager = TokenManager(self) + self.url_builder = URLBuilder(self) + # Initialize server + self.server= AuthServer(self) + + def _generate_state(self) -> str: + """Generate a secure random state and store it""" + state = secrets.token_urlsafe(16) + self.state_store[state] = {"is_completed": False} + return state + async def interactive_login( + self, + connection: str | None = None, + scope: str | None = None, + **kwargs + ) -> User: + """ + Handle interactive login flow. + Args: + connection: Optional connection to use + scope: OAuth scope (default: "openid profile email") + **kwargs: Additional parameters for authorization + Returns: + User instance if successful, error string if failed + """ + if scope is None: + scope = "openid profile email" + # Generate state and create login state tracker + state = self._generate_state() + login_state = LoginState(self.state_store, state) + # Generate authorization URL + auth_url = self.url_builder.get_authorize_url( + state=state, + connection=connection, + scope=scope, + **kwargs + ) + # Open browser for authentication + try: + webbrowser.open(auth_url) + except webbrowser.Error: + print(f"Please navigate here: {auth_url}") + # Wait for authentication completion + user_id = await login_state.wait_for_completion() + if not user_id: + return "login failed" + return User(self, user_id=user_id) + + async def link( + self, + primary_user_id: str, + connection: str, + id_token: str, + scope: str | None = None, + **kwargs + ) -> Dict[str, Any]: + """ + Handle account linking flow. + Args: + primary_user_id: ID of the user initiating the link + connection: Connection to link + id_token: ID token of the primary user + scope: OAuth scope + **kwargs: Additional parameters for authorization + Returns: + Dict containing link status and user information + """ + state = self._generate_state() + link_state = LinkState(self.state_store, state) + link_state.set_user(primary_user_id) + + auth_url = self.url_builder.get_authorize_url( + state = state, + scope = "link_account", + audience = "my-account", + requested_connection = connection, + requested_connection_scope = scope, + id_token_hint = id_token, + client_id = self.client_id, + redirect_uri = self.redirect_uri, + ) + + try: + try: + webbrowser.open(auth_url) + except webbrowser.Error: + print(f"Please navigate here: {auth_url}") + + # Wait for linking completion + user_id = await link_state.wait_for_completion() + return { + "is_successful": bool(user_id), + "user_id": user_id or primary_user_id + } + except Exception as error: + print(f"Error during linking: {error}") + return { + "is_successful": False, + "user_id": primary_user_id, + } + + async def unlink( + self, + primary_user_id: str, + connection: str, + id_token: str, + **kwargs + ) -> Dict[str, Any]: + """ + Handle account linking flow. + Args: + primary_user_id: ID of the user initiating the link + connection: Connection to link + id_token: ID token of the primary user + scope: OAuth scope + **kwargs: Additional parameters for authorization + Returns: + Dict containing link status and user information + """ + state = self._generate_state() + link_state = LinkState(self.state_store, state) + link_state.set_user(primary_user_id) + + auth_url = self.url_builder.get_authorize_url( + state = state, + scope = "unlink_account", + audience = "my-account", + requested_connection = connection, + id_token_hint = id_token, + client_id = self.client_id, + redirect_uri = self.redirect_uri, + ) + + try: + try: + webbrowser.open(auth_url) + except webbrowser.Error: + print(f"Please navigate here: {auth_url}") + + # Wait for linking completion + user_id = await link_state.wait_for_completion() + return { + "is_successful": bool(user_id), + "user_id": user_id or primary_user_id + } + except Exception as error: + print(f"Error during unlinking: {error}") + return { + "is_successful": False, + "user_id": primary_user_id, + } + + def get_session(self, user: User) -> Dict[str, Any]: + """Get session for a user object""" + return self.session_manager.get_session(user = user) + + + def get_upstream_token( + self, + connection: str, + refresh_token: str, + additional_scopes: str | None = None + ) -> Dict[str, Any]: + """Get token for federated connection""" + return self.token_manager.get_upstream_token( + connection=connection, + refresh_token=refresh_token, + additional_scopes=additional_scopes + ) + + + + + + + + + diff --git a/packages/auth0-ai/auth0_ai/auth/base.py b/packages/auth0-ai/auth0_ai/auth/base.py new file mode 100644 index 0000000..f7cfd81 --- /dev/null +++ b/packages/auth0-ai/auth0_ai/auth/base.py @@ -0,0 +1,47 @@ +from __future__ import annotations +import os +from auth0.authentication.base import AuthenticationBase + +class BaseAuth(AuthenticationBase): + """Base authentication class with core properties and validations""" + # Define required config fields and their environment variable names + REQUIRED_CONFIGS = { + 'domain': 'AUTH0_DOMAIN', + 'client_id': 'AUTH0_CLIENT_ID', + 'client_secret': 'AUTH0_CLIENT_SECRET', + 'redirect_uri': 'AUTH0_REDIRECT_URI', + 'secret_key': 'AUTH0_SECRET_KEY' + } + + def __init__( + self, + domain: str | None = None, + client_id: str | None = None, + client_secret: str | None = None, + redirect_uri: str | None = None, + secret_key: str | None = None, + *args, **kwargs): + + # Initialize all config properties + for field, env_var in self.REQUIRED_CONFIGS.items(): + value = locals().get(field) or os.environ.get(env_var) + setattr(self, f'_{field}', None) # Initialize private attribute + setattr(self.__class__, field, property( # Create property + fget=lambda self, f=field: getattr(self, f'_{f}'), + fset=lambda self, value, f=field: self._validate_and_set(f, value) + )) + setattr(self, field, value) # Set the value using property setter + + super().__init__( + domain=self.domain, + client_id=self.client_id, + client_secret=self.client_secret, + *args, **kwargs + ) + + def _validate_and_set(self, field: str, value: str | None) -> None: + """Validate and set a configuration value""" + if not value: + raise ValueError( + f"{field} cannot be empty. You can also set {self.REQUIRED_CONFIGS[field]} value in .env file") + setattr(self, f'_{field}', value) \ No newline at end of file diff --git a/packages/auth0-ai/auth0_ai/auth/user.py b/packages/auth0-ai/auth0_ai/auth/user.py new file mode 100644 index 0000000..880a013 --- /dev/null +++ b/packages/auth0-ai/auth0_ai/auth/user.py @@ -0,0 +1,128 @@ +from __future__ import annotations +from typing import Any, Dict + +class User: + """ + Represents an authenticated user with token management capabilities. + """ + def __init__(self, auth_client, user_id: str): + """ + Initialize a User instance. + Args: + auth_client: The parent AIAuth instance + user_id: The unique identifier for the user + """ + self._auth_client = auth_client + self._user_id = user_id + @property + def user_id(self) -> str: + """Get the user's ID""" + return self._user_id + async def link(self, connection: str, scope: str | None = None, **kwargs) -> Dict[str, Any]: + """ + Link another authentication provider to this user account. + Args: + connection: The name of the connection to link (e.g., 'github', 'google') + scope: OAuth scope for the connection + **kwargs: Additional parameters for the linking process + Returns: + Dict containing link status and user information + """ + return await self._auth_client.link( + primary_user_id=self.user_id, + connection=connection, + id_token=self.get_id_token(), + scope=scope, + **kwargs + ) + async def unlink(self, connection: str, scope: str | None = None, **kwargs) -> Dict[str, Any]: + """ + UnLink an existing authentication provider to this user account. + Args: + connection: The name of the connection to unlink (e.g., 'github', 'google') + **kwargs: Additional parameters for the unlinking process + Returns: + Dict containing unlink status and user information + """ + return await self._auth_client.unlink( + primary_user_id=self.user_id, + connection=connection, + id_token=self.get_id_token(), + scope=scope, + **kwargs + ) + def get_id_token(self) -> str: + """Get the user's ID token""" + return self._auth_client.token_manager.get_id_token(self.user_id) + + def get_access_token(self) -> str: + """Get the user's access token""" + return self._auth_client.token_manager.get_access_token(self.user_id) + + def get_refresh_token(self) -> str: + """Get the user's refresh token""" + return self._auth_client.token_manager.get_refresh_token(self.user_id) + + def get_3rd_party_token(self, connection: str) -> Dict[str, Any]: + """ + Get access token for a linked third-party provider. + Args: + connection: The name of the third-party connection (e.g., 'github') + Returns: + Dict containing the third-party access token and related information + """ + refresh_token = self.get_refresh_token() + return self._auth_client.get_upstream_token(connection, refresh_token) + + def get_profile(self) -> Dict[str, Any]: + """ + Get the user's profile information. + Returns: + Dict containing user profile data + """ + access_token = self.get_access_token() + return self._auth_client.token_manager.get_userinfo(access_token) + + async def get_token_info(self) -> Dict[str, Any]: + """ + Get detailed information about the user's tokens. + Returns: + Dict containing token information + """ + id_token = self.get_id_token() + access_token = self.get_access_token() + return await self._auth_client.token_manager.get_tokeninfo(id_token, access_token) + + def is_token_valid(self) -> bool: + """ + Check if the user's tokens are still valid. + Returns: + bool indicating token validity + """ + return self._auth_client.token_manager.validate_tokens(self.user_id) + + async def refresh_tokens(self) -> bool: + """ + Refresh the user's tokens if they're expired. + Returns: + bool indicating if refresh was successful + """ + refresh_token = self.get_refresh_token() + if not refresh_token: + return False + return await self._auth_client.token_manager.refresh_tokens(self.user_id, refresh_token) + + def get_session(self) -> Dict[str, Any]: + """ + Get the current session information for the user. + Returns: + Dict containing session information + """ + return self._auth_client.get_session(self) + + + + + + + diff --git a/packages/auth0-ai/auth0_ai/server/__init__.py b/packages/auth0-ai/auth0_ai/server/__init__.py new file mode 100644 index 0000000..4dee87a --- /dev/null +++ b/packages/auth0-ai/auth0_ai/server/__init__.py @@ -0,0 +1,8 @@ +""" +Auth0 AI Server Module +Internal module for handling OAuth callback server and routes. +""" +from .auth_server import AuthServer +from .routes import setup_routes + +__all__ = ["AuthServer","setup_routes"] diff --git a/packages/auth0-ai/auth0_ai/server/auth_server.py b/packages/auth0-ai/auth0_ai/server/auth_server.py new file mode 100644 index 0000000..a371042 --- /dev/null +++ b/packages/auth0-ai/auth0_ai/server/auth_server.py @@ -0,0 +1,83 @@ +from __future__ import annotations +import threading +import urllib.parse +from typing import Any + +import uvicorn +import os +from fastapi import FastAPI + +from .routes import setup_routes + +class AuthServer: + """ + FastAPI server handling Auth0 callbacks and authentication routes. + """ + + def __init__(self, auth_client: Any): + """ + Initialize the authentication server. + + Args: + auth_client: The parent AIAuth instance + """ + self.auth_client = auth_client + self.app = FastAPI() + + # Parse redirect URI for server config + parsed_uri = urllib.parse.urlparse(auth_client.redirect_uri) + self.host = parsed_uri.hostname + self.port = parsed_uri.port + self.protocol = urllib.parse.urlparse(auth_client.redirect_uri).scheme + + # Setup routes with dependencies + setup_routes(self.app, auth_client) + self.start() + + def _is_valid_file(self,file_path) -> bool: + """Check if the file exists and is accessible.""" + valid = False + try: + valid = os.path.isfile(file_path) and os.access(file_path, os.R_OK) + return valid + except Exception as e: + return valid + + + def start(self): + """Runs FastAPI as the middleware inside a separate thread.""" + if (self.protocol == "https"): + + ssl_keyfile = os.getenv("AUTH0_SSL_KEYFILE") + ssl_certfile = os.getenv("AUTH0_SSL_CERTFILE") + + if not self._is_valid_file(ssl_keyfile) or not self._is_valid_file(ssl_certfile): + raise ValueError( + "AUTH0_SSL_KEYFILE and AUTH0_SSL_CERTFILE environment variables must be set with valid file paths for HTTPS.") + + server_thread = threading.Thread( + target=uvicorn.run, + args=(self.app,), + kwargs={ + "host": self.host, + "port": self.port, + "ssl_keyfile": ssl_keyfile, # Path to private key + "ssl_certfile": ssl_certfile, # Path to certificate + "log_level": "error"}, + daemon=True # Daemon mode so it exits when the main thread exits + ) + else: + server_thread = threading.Thread( + target=uvicorn.run, + args=(self.app,), + kwargs={ + "host": self.host, + "port": self.port, + "log_level": "info"}, + daemon=True # Daemon mode so it exits when the main thread exits + ) + try: + server_thread.start() + except Exception as e: + print(f"Error starting middleware server: {str(e)}") + raise e \ No newline at end of file diff --git a/packages/auth0-ai/auth0_ai/server/routes.py b/packages/auth0-ai/auth0_ai/server/routes.py new file mode 100644 index 0000000..549fd18 --- /dev/null +++ b/packages/auth0-ai/auth0_ai/server/routes.py @@ -0,0 +1,105 @@ +from __future__ import annotations +from typing import Any + +from fastapi import FastAPI, Request, Response, HTTPException +from fastapi.responses import JSONResponse, RedirectResponse +import jwt + + +def setup_routes(app: FastAPI, auth_client: Any) -> None: + """Set up all routes for the authentication server.""" + + @app.get("/auth/callback") + async def manage_callback(request: Request, response: Response): + """Parses and validates callback URL query parameters.""" + query_params = request.query_params + required_keys = {"code", "state"} + + if query_params.get("error"): + error_description = query_params.get("error_description", "Unknown error occurred.") + if query_params.get("state"): + del auth_client.state_store[query_params.get("state")] + raise HTTPException(status_code=400, detail=error_description) + + if not required_keys.issubset(query_params.keys()): + raise HTTPException(status_code=400, detail="Missing required query parameters.") + + received_state = query_params["state"] + + # Validate state to prevent CSRF attacks + if received_state not in auth_client.state_store: + raise HTTPException(status_code=400, detail="Invalid or missing state parameter.") + + # Extract code value from query string + received_code = query_params["code"] + + auth0_tokens = auth_client.token_manager.exchange_code_for_tokens(received_code) + + if auth0_tokens: + cookie_data = await auth_client.session_manager.set_encrypted_session(auth0_tokens, state=received_state) + + response.set_cookie( + key="session", + value=cookie_data, + httponly=True, # Prevent JavaScript access + # secure=True, # Send only over HTTPS + samesite="Lax", # Protect against CSRF + # set expiry based on access token expiry + max_age=auth0_tokens["expires_in"], + ) + + user_id = auth_client.state_store[received_state].get("user_id", "failed") + auth_client.state_store[received_state] = {"user_id": user_id, "is_completed": True} + + return {"message": "successful. you can now close this window"} + + @app.get("/auth/login") + async def manage_login(request: Request, response: Response): + """Handle login initiation.""" + # if scope is None: # Original comment preserved + scope = "openid profile email" + connection = "Username-Password-Authentication" + + state = auth_client._generate_state() + + auth_url = auth_client.url_builder.get_authorize_url( + state=state, + connection=connection, + scope=scope + ) + + return RedirectResponse(url=auth_url, status_code=302) + + @app.get("/auth/get_user") + async def get_user(request: Request): + """Reads the session cookie and extracts user info.""" + auth_cookie = request.cookies.get("session") + + if not auth_cookie: + raise HTTPException(status_code=401, detail="Missing session cookie.") + + try: + # Decode the JWT stored in the session cookie + decoded_data = jwt.decode(auth_cookie, auth_client.secret_key, algorithms=["HS256"]) + + # Extract the user ID (sub) from the decoded JWT + user_id = decoded_data["user_id"] + + if not user_id: + raise HTTPException(status_code=400, detail="Invalid session cookie: Missing 'sub' claim.") + + return JSONResponse(content=decoded_data) + + except jwt.ExpiredSignatureError: + raise HTTPException(status_code=401, detail="Session cookie has expired.") + except jwt.InvalidTokenError: + raise HTTPException(status_code=401, detail="Invalid session cookie.") + + + + + + + + + diff --git a/packages/auth0-ai/auth0_ai/session_module/__init__.py b/packages/auth0-ai/auth0_ai/session_module/__init__.py new file mode 100644 index 0000000..8e41a90 --- /dev/null +++ b/packages/auth0-ai/auth0_ai/session_module/__init__.py @@ -0,0 +1,12 @@ +""" +Session Management Module +Provides session handling, storage, and encryption capabilities. +""" +from .manager import SessionManager +from .storage.base_store import BaseStore +from .storage.local_store import LocalStore +__all__ = [ + "SessionManager", + "BaseStore", + "LocalStore" +] diff --git a/packages/auth0-ai/auth0_ai/session_module/manager.py b/packages/auth0-ai/auth0_ai/session_module/manager.py new file mode 100644 index 0000000..2b2833e --- /dev/null +++ b/packages/auth0-ai/auth0_ai/session_module/manager.py @@ -0,0 +1,181 @@ +from __future__ import annotations +from typing import Any, Dict, Optional +import jwt +import time + +from .storage.base_store import BaseStore +from .storage.local_store import LocalStore + +class SessionManager: + """ + Manages session operations including encryption, storage, and retrieval. + Maintains exact compatibility with original implementation. + """ + + def __init__( + self, + auth_client: Any, + use_local_cache: bool = True, + get_ext_sessions=None, + get_ext_session=None, + set_ext_session=None, + delete_ext_session=None, + store: Optional[BaseStore] = None + ): + """ + Initialize session manager with original parameters plus optional store. + + Args: + auth_client: Parent AIAuth instance + use_local_cache: Whether to use local cache (default: True) + get_ext_sessions: Optional custom get_sessions function + get_ext_session: Optional custom get_session function + set_ext_session: Optional custom set_session function + delete_ext_session: Optional custom delete_session function + store: Optional custom store implementation + """ + self.auth_client = auth_client + self.store = store or LocalStore(use_local_cache=use_local_cache) + self.secret_key = auth_client.secret_key + + # Custom function handlers + self.get_ext_sessions = get_ext_sessions + self.get_ext_session = get_ext_session + self.set_ext_session = set_ext_session + self.delete_ext_session = delete_ext_session + + # Original interface methods with exact same names and signatures + def _get_stored_sessions(self) -> Any: + """Get all stored session IDs""" + if hasattr(self, 'get_ext_sessions') and self.get_ext_sessions: + return self.get_ext_sessions() + return self.store.get_stored_sessions() + + def _get_stored_session(self, user_id: str) -> str: + """Get a specific stored session""" + if hasattr(self, 'get_ext_session') and self.get_ext_session: + return self.get_ext_session() + return self.store.get_stored_session(user_id) + + def _set_stored_session(self, user_id: str, encrypted_session_data: str) -> None: + """Store a session""" + if hasattr(self, 'set_ext_session') and self.set_ext_session: + self.set_ext_session() + else: + self.store.set_stored_session(user_id, encrypted_session_data) + + def _delete_stored_session(self, user_id: str) -> None: + """Delete a stored session""" + if hasattr(self, 'delete_ext_session') and self.delete_ext_session: + self.delete_ext_session() + else: + self.store.delete_stored_session(user_id) + + # Session encryption and management methods (from original auth_client.py) + async def set_encrypted_session(self, token_data: dict, state: str | None = None) -> str: + """Create or update encrypted session""" + id_token = token_data.get("id_token", "") + if id_token: + try: + decoded_id_token = await self.auth_client.token_verifier.verify_signature(id_token) + user_id = decoded_id_token.get("sub") + if not user_id: + raise ValueError("ID token missing 'sub' claim.") + except Exception as e: + raise ValueError(f"Invalid ID token: {str(e)}") + else: + user_id = self.auth_client.state_store[state].get("user_id") if state else None + + existing_encrypted_session = self._get_stored_session(user_id) + existing_linked_connections = {} + existing_refresh_token = {} + + if existing_encrypted_session: + existing_session = self.get_encrypted_session(user_id) + existing_refresh_token = existing_session.get("tokens", {}).get("refresh_token", None) + existing_linked_connections = existing_session.get("linked_connections") + + session_data = { + "user_id": user_id, + "tokens": self._get_token_set(token_data, existing_refresh_token), + "linked_connections": self._get_linked_details(token_data, existing_linked_connections) + } + + encrypted_session_data = jwt.encode(session_data, self.secret_key, algorithm="HS256") + self._set_stored_session(user_id, encrypted_session_data) + + if state: + self.auth_client.state_store[state] = {"user_id": user_id} + + return encrypted_session_data + + def get_encrypted_session(self, user_id: str) -> Dict[str, Any]: + """Retrieve and decrypt session data""" + encrypted_session = self._get_stored_session(user_id) + + if not encrypted_session: + return {"not found"} + + try: + decoded_data = jwt.decode(encrypted_session, self.secret_key, algorithms=["HS256"]) + + token_expiry = decoded_data.get("tokens", {}).get("expires_at", {}).get("epoch") + if token_expiry > int(time.time()): + return decoded_data + + refresh_token = decoded_data.get("tokens", {}).get("refresh_token") + if refresh_token: + self._update_encrypted_session(user_id, refresh_token) + else: + self._delete_stored_session(user_id) + return {"session expired"} + + except jwt.ExpiredSignatureError: + return {"Session cookie has expired."} + except jwt.InvalidTokenError: + return {"Invalid session."} + + def _update_encrypted_session(self, user_id: str, refresh_token: str) -> None: + """Update session with refreshed tokens""" + token_manager = self.auth_client.token_manager + updated_tokens = token_manager.refresh_token(refresh_token=refresh_token) + if updated_tokens: + self.set_encrypted_session(updated_tokens) + + + def get_session_details(self, user_id: str) -> Dict[str, Any]: + """Get session details for user""" + if user_id in self._get_stored_sessions(): + return self.get_encrypted_session(user_id) + return {"user_id not found in session store"} + + def get_session(self, user: Any) -> Dict[str, Any]: + """Get session for user object""" + if user.user_id in self._get_stored_sessions(): + return self.get_encrypted_session(user.user_id) + return {"user_id not found in session store"} + + def _get_token_set(self, token_data: dict, existing_refresh_token: str | None = None) -> dict: + """Format token data with expiry time""" + return { + "access_token": token_data.get("access_token"), + "expires_at": {"epoch": int(time.time()) + token_data["expires_in"]}, + "refresh_token": token_data.get("refresh_token", existing_refresh_token), + "id_token": token_data.get("id_token"), + "scope": token_data.get("scope"), + } + + def _get_linked_details(self, token_data: dict, existing_linked_connections: list[str] | None = None) -> list[str]: + """Extract linked connections from token data""" + linked_connections = set(existing_linked_connections or []) + + for item in token_data.get("authorization_details", []): + if item.get("type") == "account_linking": + link_with = item.get("linkParams", {}).get("link_with") + if link_with: + linked_connections.add(link_with) + + return list(linked_connections) + + + diff --git a/packages/auth0-ai/auth0_ai/session_module/storage/__init__.py b/packages/auth0-ai/auth0_ai/session_module/storage/__init__.py new file mode 100644 index 0000000..bfc0e98 --- /dev/null +++ b/packages/auth0-ai/auth0_ai/session_module/storage/__init__.py @@ -0,0 +1,7 @@ +""" +Session Storage Implementations +""" +from .base_store import BaseStore +from .local_store import LocalStore + +__all__ = ["BaseStore", "LocalStore"] diff --git a/packages/auth0-ai/auth0_ai/session_module/storage/base_store.py b/packages/auth0-ai/auth0_ai/session_module/storage/base_store.py new file mode 100644 index 0000000..5e57a74 --- /dev/null +++ b/packages/auth0-ai/auth0_ai/session_module/storage/base_store.py @@ -0,0 +1,45 @@ +from __future__ import annotations +from abc import ABC, abstractmethod +from typing import List + +class BaseStore(ABC): + """ + Abstract base class defining the interface for session storage implementations. + All storage implementations must inherit from this class and implement + all abstract methods. + """ + @abstractmethod + def get_stored_sessions(self) -> List[str]: + """ + Get all stored session IDs. + Returns: + List of session IDs + """ + pass + @abstractmethod + def get_stored_session(self, user_id: str) -> str | None: + """ + Get a specific stored session. + Args: + user_id: The ID of the user whose session to retrieve + Returns: + The session data if found, None otherwise + """ + pass + @abstractmethod + def set_stored_session(self, user_id: str, encrypted_session_data: str) -> None: + """ + Store a session. + Args: + user_id: The ID of the user whose session to store + encrypted_session_data: The encrypted session data to store + """ + pass + @abstractmethod + def delete_stored_session(self, user_id: str) -> None: + """ + Delete a stored session. + Args: + user_id: The ID of the user whose session to delete + """ + pass \ No newline at end of file diff --git a/packages/auth0-ai/auth0_ai/session_module/storage/local_store.py b/packages/auth0-ai/auth0_ai/session_module/storage/local_store.py new file mode 100644 index 0000000..5a8540d --- /dev/null +++ b/packages/auth0-ai/auth0_ai/session_module/storage/local_store.py @@ -0,0 +1,51 @@ +from __future__ import annotations +import shelve +import os +from typing import List + +from .base_store import BaseStore + +class LocalStore(BaseStore): + """ + Local storage implementation using Python's shelve module. + This is the default storage mechanism, maintaining the original implementation's behavior. + """ + + def __init__(self, file_path: str = ".sessions_cache", use_local_cache: bool = True): + """ + Initialize local store. + + Args: + file_path: Path to the shelve file (default: ".sessions_cache") + use_local_cache: Flag to determine if local cache should be used (default: True) + """ + self.file_path = file_path + self.use_local_cache = use_local_cache or os.environ.get("AUTH0_USE_LOCAL_CACHE", True) + + def get_stored_sessions(self) -> List[str]: + """Get all stored session IDs""" + if self.use_local_cache: + with shelve.open(self.file_path) as sessions: + return list(sessions.keys()) + return [] + + def get_stored_session(self, user_id: str) -> str | None: + """Get a specific stored session""" + if self.use_local_cache: + with shelve.open(self.file_path) as sessions: + return sessions.get(user_id) + return None + + def set_stored_session(self, user_id: str, encrypted_session_data: str) -> None: + """Store a session""" + if self.use_local_cache: + with shelve.open(self.file_path) as sessions: + sessions[user_id] = encrypted_session_data + sessions.sync() + + def delete_stored_session(self, user_id: str) -> None: + """Delete a stored session""" + if self.use_local_cache: + with shelve.open(self.file_path) as sessions: + if user_id in sessions: + del sessions[user_id] diff --git a/packages/auth0-ai/auth0_ai/session_storage.py b/packages/auth0-ai/auth0_ai/session_storage.py new file mode 100644 index 0000000..408b59d --- /dev/null +++ b/packages/auth0-ai/auth0_ai/session_storage.py @@ -0,0 +1,61 @@ +import shelve +from typing import Any + + +class SessionStorage: + + def __init__( + self, + use_local_cache: bool = True, + get_sessions=None, + get_session=None, + set_session=None, + delete_session=None, + ): + + self.get_sessions = get_sessions + self.get_session = get_session + self.set_session = set_session + self.delete_session = delete_session + + self.use_local_cache = use_local_cache or os.environ.get( + "AUTH0_USE_LOCAL_CACHE") + + @property + def use_local_cache(self): + return self._use_local_cache + + @use_local_cache.setter + def use_local_cache(self, val): + self._use_local_cache = val + + def _get_stored_sessions(self) -> Any: + if (self.use_local_cache): + with shelve.open(".sessions_cache") as sessions: + return list(sessions.keys()) + else: + return self.get_session() + + def _get_stored_session(self, user_id: str) -> str: + if (self.use_local_cache): + with shelve.open(".sessions_cache") as sessions: + return sessions.get(user_id) + else: + return self.get_session() + + def _set_stored_session(self, user_id, encrypted_session_data): + if (self.use_local_cache): + with shelve.open(".sessions_cache") as sessions: + sessions[user_id] = encrypted_session_data + sessions.sync() + else: + self.set_session() + + def _delete_stored_session(self, user_id): + if (self.use_local_cache): + with shelve.open(".sessions_cache") as sessions: + if user_id in sessions: + del sessions[user_id] + sessions.sync() + else: + self.del_session() diff --git a/packages/auth0-ai/auth0_ai/state/__init__.py b/packages/auth0-ai/auth0_ai/state/__init__.py new file mode 100644 index 0000000..13635f8 --- /dev/null +++ b/packages/auth0-ai/auth0_ai/state/__init__.py @@ -0,0 +1,9 @@ +""" +Auth0 AI State Management Module +Internal module for handling authentication and linking state. +""" +from .base_state import BaseState +from .login_state import LoginState +from .link_state import LinkState + +__all__ = ["BaseState", "LoginState", "LinkState"] \ No newline at end of file diff --git a/packages/auth0-ai/auth0_ai/state/base_state.py b/packages/auth0-ai/auth0_ai/state/base_state.py new file mode 100644 index 0000000..9294aab --- /dev/null +++ b/packages/auth0-ai/auth0_ai/state/base_state.py @@ -0,0 +1,38 @@ +from __future__ import annotations +import asyncio +from typing import Any, Dict +from abc import ABC, abstractmethod + +class BaseState(ABC): + """ + Base class for state management in authentication flows. + """ + def __init__(self, state_store: Dict[str, Dict[str, Any]], state: str): + """ + Initialize base state tracker. + Args: + state_store: Reference to the global state store + state: Unique state identifier for this flow + """ + self.state_store = state_store + self.state = state + + @abstractmethod + def is_completed(self) -> bool: + """Check if flow is completed""" + pass + @abstractmethod + def get_user(self) -> Any: + """Get user information after flow completion""" + pass + @abstractmethod + def complete(self, user_id: str) -> None: + """Mark flow as complete with user information""" + pass + def terminate(self) -> None: + """Clean up state data""" + if self.state in self.state_store: + del self.state_store[self.state] + async def _sleep(self, seconds: float) -> None: + """Async sleep helper""" + await asyncio.sleep(seconds) \ No newline at end of file diff --git a/packages/auth0-ai/auth0_ai/state/link_state.py b/packages/auth0-ai/auth0_ai/state/link_state.py new file mode 100644 index 0000000..9a7d862 --- /dev/null +++ b/packages/auth0-ai/auth0_ai/state/link_state.py @@ -0,0 +1,65 @@ +from __future__ import annotations +import time +from typing import Any, Dict, Optional +from .base_state import BaseState + +class LinkState(BaseState): + """ + Handles the state management for the account linking flow. + """ + def __init__(self, state_store: Dict[str, Dict[str, Any]], state: str): + """ + Initialize link state tracker. + Args: + state_store: Reference to the global state store + state: Unique state identifier for this linking attempt + """ + super().__init__(state_store, state) + self.start_time = time.time() + self.timeout = 60 # Linking timeout in seconds + + def is_completed(self) -> bool: + """Check if linking flow is completed""" + return self.state_store[self.state].get("is_completed", False) + def get_user(self) -> str: + """Get user information after linking completion""" + return self.state_store.get(self.state, "login failed!").get("user_id") + def set_user(self, user_id: str) -> None: + """ + Set the primary user ID for linking. + Args: + user_id: ID of the user initiating the link + """ + is_completed = self.state_store[self.state].get("is_completed", False) + self.state_store[self.state] = { + "is_completed": is_completed, + "user_id": user_id + } + + def complete(self, user_id: str) -> None: + """ + Mark linking as complete with user information. + Args: + user_id: ID of the linked user + """ + self.state_store[self.state] = { + "user_id": user_id, + "is_completed": True + } + + async def wait_for_completion(self) -> Optional[str]: + """ + Wait for linking completion or timeout. + Returns: + User ID if successful, None if timeout or failure + """ + while not self.is_completed(): + if time.time() > self.start_time + self.timeout: + self.terminate() + return None + await self._sleep(0.25) # Small delay between checks + user_id = self.get_user() + self.terminate() + if user_id == "login failed!": + return None + return user_id \ No newline at end of file diff --git a/packages/auth0-ai/auth0_ai/state/login_state.py b/packages/auth0-ai/auth0_ai/state/login_state.py new file mode 100644 index 0000000..0f27891 --- /dev/null +++ b/packages/auth0-ai/auth0_ai/state/login_state.py @@ -0,0 +1,63 @@ +from __future__ import annotations +import time +from typing import Any, Dict, Optional + +from .base_state import BaseState + +class LoginState(BaseState): + """ + Handles the state management for the login flow. + """ + + def __init__(self, state_store: Dict[str, Dict[str, Any]], state: str): + """ + Initialize login state tracker. + + Args: + state_store: Reference to the global state store + state: Unique state identifier for this login attempt + """ + super().__init__(state_store, state) + self.start_time = time.time() + self.timeout = 60 # Login timeout in seconds + + def is_completed(self) -> bool: + """Check if login flow is completed""" + return self.state_store[self.state].get("is_completed", False) + + def get_user(self) -> str | Dict[str, Any]: + """Get user information after login completion""" + return self.state_store.get(self.state, "login failed!") + + def complete(self, user_id: str) -> None: + """ + Mark login as complete with user information. + + Args: + user_id: ID of the authenticated user + """ + self.state_store[self.state] = { + "user_id": user_id, + "is_completed": True + } + + async def wait_for_completion(self) -> Optional[str]: + """ + Wait for login completion or timeout. + + Returns: + User ID if successful, None if timeout or failure + """ + while not self.is_completed(): + if time.time() > self.start_time + self.timeout: + self.terminate() + return None + await self._sleep(0.25) # Small delay between checks + + user_data = self.get_user() + self.terminate() + + if user_data == "login failed!": + return None + + return user_data.get("user_id") \ No newline at end of file diff --git a/packages/auth0-ai/auth0_ai/token_module/__init__.py b/packages/auth0-ai/auth0_ai/token_module/__init__.py new file mode 100644 index 0000000..5ca6543 --- /dev/null +++ b/packages/auth0-ai/auth0_ai/token_module/__init__.py @@ -0,0 +1,6 @@ +""" +Auth0 AI Token Management Module +Internal module for handling token operations and lifecycle. +""" +from .manager import TokenManager +__all__ = ["TokenManager"] diff --git a/packages/auth0-ai/auth0_ai/token_module/manager.py b/packages/auth0-ai/auth0_ai/token_module/manager.py new file mode 100644 index 0000000..63c7a8e --- /dev/null +++ b/packages/auth0-ai/auth0_ai/token_module/manager.py @@ -0,0 +1,183 @@ +from __future__ import annotations +from typing import Any, Dict +import time +from auth0.authentication import GetToken +from auth0.authentication.async_token_verifier import AsyncAsymmetricSignatureVerifier + + +class TokenManager: + """ + Manages token operations, including exchange, refresh, and validation. + """ + def __init__(self, auth_client: Any): + """ + Initialize token manager. + Args: + auth_client: Parent AIAuth instance + """ + self.auth_client = auth_client + self.token_verifier = AsyncAsymmetricSignatureVerifier( + jwks_url=f"https://{auth_client.domain}/.well-known/jwks.json" + ) + + def exchange_code_for_tokens(self, code: str) -> Dict[str, Any]: + """ + Exchange authorization code for tokens. + Args: + code: Authorization code from Auth0 + Returns: + Dict containing access token, refresh token, and ID token + """ + get_token = GetToken( + self.auth_client.domain, + self.auth_client.client_id, + self.auth_client.client_secret + ) + return get_token.authorization_code( + code=code, + redirect_uri=self.auth_client.redirect_uri, + grant_type="authorization_code" + ) + + def get_token_set(self, token_data: dict, existing_refresh_token: str | None = None) -> dict: + """ + Format token data with expiry time. + Args: + token_data: Raw token data from Auth0 + existing_refresh_token: Optional existing refresh token to preserve + Returns: + Formatted token data with expiry information + """ + return { + "access_token": token_data.get("access_token"), + "expires_at": {"epoch": int(time.time()) + token_data["expires_in"]}, + "refresh_token": token_data.get("refresh_token", existing_refresh_token), + "id_token": token_data.get("id_token"), + "scope": token_data.get("scope"), + } + + async def verify_id_token(self, id_token: str) -> Dict[str, Any]: + """ + Verify and decode ID token. + Args: + id_token: ID token to verify + Returns: + Decoded token claims + """ + return await self.token_verifier.verify_signature(id_token) + + def refresh_tokens(self, refresh_token: str) -> Dict[str, Any]: + """ + Refresh access token using refresh token. + Args: + refresh_token: Refresh token to use + Returns: + New token set + """ + token_client = GetToken( + self.auth_client.domain, + self.auth_client.client_id, + self.auth_client.client_secret + ) + return token_client.refresh_token(refresh_token=refresh_token) + + + def get_3rd_party_token(self, connection: str) -> dict[str, Any]: + return self.get_upstream_token(connection, self.get_refresh_token()) + + def tokeninfo(self) -> dict[str, Any]: + id_token = self.get_id_token() + self.parent.post() + data: dict[str, Any] = self.parent.get( + url=f"https://{self.parent.domain}/tokeninfo", + headers={"Authorization": f"Bearer {access_token}"}, + ) + return data + + def get_upstream_token( + self, + connection: str, + refresh_token: str, + additional_scopes: str | None = None + ) -> Dict[str, Any]: + """ + Get token for federated connection. + Args: + connection: Name of the connection (e.g., 'github') + refresh_token: Refresh token to use + additional_scopes: Optional additional scopes to request + Returns: + Token for the federated connection + """ + token_client = GetToken( + self.auth_client.domain, + self.auth_client.client_id, + self.auth_client.client_secret + ) + return token_client.federated_connection_access_token( + subject_token_type="urn:ietf:params:oauth:token-type:refresh_token", + subject_token=refresh_token, + requested_token_type="http://auth0.com/oauth/token-type/federated-connection-access-token", + connection=connection, + grant_type="urn:auth0:params:oauth:grant-type:token-exchange:federated-connection-access-token" + ) + + def get_userinfo(self, access_token: str) -> Dict[str, Any]: + """ + Get user information using access token. + Args: + access_token: Access token to use + Returns: + User profile information + """ + return self.auth_client.get( + url=f"https://{self.auth_client.domain}/userinfo", + headers={"Authorization": f"Bearer {access_token}"} + ) + + async def get_tokeninfo(self, id_token: str, access_token: str) -> Dict[str, Any]: + """ + Get detailed token information. + Args: + id_token: ID token + access_token: Access token + Returns: + Detailed token information + """ + return await self.auth_client.get( + url=f"https://{self.auth_client.domain}/tokeninfo", + headers={"Authorization": f"Bearer {access_token}"} + ) + + def validate_tokens(self, token_data: Dict[str, Any]) -> bool: + """ + Check if tokens are still valid. + Args: + token_data: Token data to validate + Returns: + True if tokens are valid, False otherwise + """ + if not token_data.get("expires_at"): + return False + expiry = token_data["expires_at"].get("epoch", 0) + return time.time() < expiry + + # Session Token Methods (used in User.py) + def get_id_token(self, user_id: str) -> Dict[str, Any]: + if user_id in self.auth_client.session_manager._get_stored_sessions(): + return (self.auth_client.session_manager.get_encrypted_session(user_id).get("tokens").get("id_token")) + else: + return {"user_id not found in session store"} + + def get_refresh_token(self, user_id: str) -> Dict[str, Any]: + if user_id in self.auth_client.session_manager._get_stored_sessions(): + return (self.auth_client.session_manager.get_encrypted_session(user_id).get("tokens").get("refresh_token")) + else: + return {"user_id not found in session store"} + + def get_access_token(self, user_id: str) -> Dict[str, Any]: + if user_id in self.auth_client.session_manager._get_stored_sessions(): + return (self.auth_client.session_manager.get_encrypted_session(user_id).get("tokens").get("access_token")) + else: + return {"user_id not found in session store"} + diff --git a/packages/auth0-ai/auth0_ai/utils/__init__.py b/packages/auth0-ai/auth0_ai/utils/__init__.py new file mode 100644 index 0000000..0a63561 --- /dev/null +++ b/packages/auth0-ai/auth0_ai/utils/__init__.py @@ -0,0 +1,7 @@ +""" +Auth0 AI Utilities Module +Provides utility functions and helpers for URL building and other common operations. +""" +from .url_builder import URLBuilder + +__all__ = ["URLBuilder"] diff --git a/packages/auth0-ai/auth0_ai/utils/url_builder.py b/packages/auth0-ai/auth0_ai/utils/url_builder.py new file mode 100644 index 0000000..0bf7a54 --- /dev/null +++ b/packages/auth0-ai/auth0_ai/utils/url_builder.py @@ -0,0 +1,117 @@ +from __future__ import annotations +from typing import Any, Dict +import json +from urllib.parse import urlencode +from auth0.authentication.pushed_authorization_requests import PushedAuthorizationRequests + +class URLBuilder: + """ + Handles construction of Auth0 authorization URLs and PAR requests. + Maintains original URL building logic from auth_client.py + """ + def __init__(self, auth_client: Any): + """ + Initialize URL builder. + Args: + auth_client: Parent AIAuth instance + """ + self.auth_client = auth_client + + def get_authorize_url( + self, + state: str, + connection: str | None = None, + scope: str | None = None, + additional_scopes: str | None = None, + **kwargs + ) -> str: + """ + Generate authorization URL. + Args: + state: State parameter for CSRF protection + connection: Auth0 connection to use + scope: OAuth scope + additional_scopes: Additional connection-specific scopes + **kwargs: Additional parameters to include in the URL + Returns: + Complete authorization URL + """ + # Base URL parameters + params = { + "response_type": "code", + "client_id": self.auth_client.client_id, + "redirect_uri": self.auth_client.redirect_uri, + "grant_type": "authorization_code", + "state": state + } + # Add optional parameters + if connection is not None: + params["connection"] = connection + if scope is not None: + params["scope"] = scope + if additional_scopes is not None: + params["connection_scope"] = additional_scopes + # Add any additional custom arguments + params.update(kwargs) + # Construct URL + query_string = urlencode(params) + return f"https://{self.auth_client.domain}/authorize?{query_string}" + + def get_authorize_par_url(self, state: str, request_uri: str) -> str: + """ + Generate PAR authorization URL. + Args: + state: State parameter for CSRF protection + request_uri: PAR request URI + Returns: + Complete PAR authorization URL + """ + params = { + "client_id": self.auth_client.client_id, + "state": state, + "request_uri": request_uri + } + query_string = urlencode(params) + return f"https://{self.auth_client.domain}/authorize?{query_string}" + async def create_par_request( + self, + state: str, + connection: str, + id_token: str, + scope: str | None = None, + **kwargs + ) -> Dict[str, Any]: + """ + Create a pushed authorization request. + Args: + state: State parameter for CSRF protection + connection: Connection to link + id_token: ID token for the primary user + scope: OAuth scope + **kwargs: Additional parameters + Returns: + PAR response containing request_uri + """ + par_client = PushedAuthorizationRequests( + self.auth_client.domain, + self.auth_client.client_id, + self.auth_client.client_secret + ) + # Prepare authorization details for account linking + auth_details = [{ + "type": "link_account", + "requested_connection": connection + }] + # Build PAR request + par_request = { + "response_type": "code", + "nonce": kwargs.get("nonce", "mynonce"), + "redirect_uri": self.auth_client.redirect_uri, + "audience": kwargs.get("audience", "my-account"), + "state": state, + "authorization_details": json.dumps(auth_details), + "scope": scope or "openid profile", + "id_token_hint": id_token, + **kwargs + } + return await par_client.pushed_authorization_request(**par_request) \ No newline at end of file diff --git a/packages/auth0-ai/pyproject.toml b/packages/auth0-ai/pyproject.toml new file mode 100644 index 0000000..aae17d1 --- /dev/null +++ b/packages/auth0-ai/pyproject.toml @@ -0,0 +1,25 @@ +[tool.poetry] +name = "auth0-ai" +version = "0.1.0" +description = "This package provides base auth capability for Auth0 AI." +license = "apache-2.0" +homepage = "https://auth0.com" +authors = [ + "Adeel Mustafa ", +] +readme = "README.md" +packages = [{ include = "auth0_ai" }] + +[tool.poetry.dependencies] +python = "^3.6" +auth0_python = "^4.8.0" +fastapi = {version = "^0.115.0", extras = ["standard"]} + +[tool.poetry.group.test.dependencies] +pytest-randomly = "^3.15.0" +pytest-asyncio = "^0.25.0" +pytest = "^8.2.0" + +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api" \ No newline at end of file