From abdd3c1d70d70d6f154149e7e1f3546da6cc0f05 Mon Sep 17 00:00:00 2001 From: mustafadeel Date: Tue, 11 Feb 2025 21:33:21 -0500 Subject: [PATCH 01/11] init --- packages/auth0-ai/README.md | 45 ++ packages/auth0-ai/auth0_ai/__init__.py | 4 + packages/auth0-ai/auth0_ai/ai_auth.py | 529 ++++++++++++++++++ packages/auth0-ai/auth0_ai/session_storage.py | 60 ++ packages/auth0-ai/pyproject.toml | 25 + 5 files changed, 663 insertions(+) create mode 100644 packages/auth0-ai/README.md create mode 100644 packages/auth0-ai/auth0_ai/__init__.py create mode 100644 packages/auth0-ai/auth0_ai/ai_auth.py create mode 100644 packages/auth0-ai/auth0_ai/session_storage.py create mode 100644 packages/auth0-ai/pyproject.toml diff --git a/packages/auth0-ai/README.md b/packages/auth0-ai/README.md new file mode 100644 index 0000000..c15a0ec --- /dev/null +++ b/packages/auth0-ai/README.md @@ -0,0 +1,45 @@ +# Auth0 AI + +This package provides base methods to use Auth0 with your AI use cases. + +## Installation + +```bash +# pip install langchain-auth0-ai +pip install git+https://github.com/atko-cic-lab/auth0-ai-python.git@main#subdirectory=packages/auth0-ai +``` + +## Running Tests + +1. **Install Dependencies** + + Use [Poetry](https://python-poetry.org/) to install the required dependencies: + + ```sh + $ poetry install + ``` + +2. **Run the tests** + + ```sh + $ poetry run pytest tests + ``` + +## Usage + +```python +tbd +``` + +--- + +

+ + + + Auth0 Logo + +

+

Auth0 is an easy to implement, adaptable authentication and authorization platform. To learn more checkout Why Auth0?

+

+This project is licensed under the Apache 2.0 license. See the LICENSE file for more info.

diff --git a/packages/auth0-ai/auth0_ai/__init__.py b/packages/auth0-ai/auth0_ai/__init__.py new file mode 100644 index 0000000..e953d42 --- /dev/null +++ b/packages/auth0-ai/auth0_ai/__init__.py @@ -0,0 +1,4 @@ +from .ai_auth import AIAuth + +__all__ = ["AIAuth"] + diff --git a/packages/auth0-ai/auth0_ai/ai_auth.py b/packages/auth0-ai/auth0_ai/ai_auth.py new file mode 100644 index 0000000..7557696 --- /dev/null +++ b/packages/auth0-ai/auth0_ai/ai_auth.py @@ -0,0 +1,529 @@ +from __future__ import annotations + +from typing import Any + +from auth0_python import AuthenticationBase, GetToken, AsyncAsymmetricSignatureVerifier, PushedAuthorizationRequests + +import webbrowser +import urllib.parse + +import os + +import jwt # PyJWT for signing cookies +from fastapi import FastAPI, Request, HTTPException, Response +from fastapi.responses import JSONResponse + +from typing import Any, Dict + +import uvicorn +import secrets +import threading + +import time +import json + + +from .session_storage import SessionStorage + + +class AIAuth(AuthenticationBase): + + def __init__( + self, + domain: str | None = None, + client_id: str | None = None, + client_secret: str | None = None, + redirect_uri: str | None = None, + *args, **kwargs): + + self.domain = domain or os.environ.get("AUTH0_DOMAIN") + self.client_id = client_id or os.environ.get("AUTH0_CLIENT_ID") + self.client_secret = client_secret or os.environ.get( + "AUTH0_CLIENT_SECRET") + self.redirect_uri = redirect_uri or os.environ.get( + "AUTH0_REDIRECT_URI") + + """Initialize AIAuth and set up the middleware app with the callback route.""" + super().__init__( + domain=self.domain, + client_id=self.client_id, + client_secret=self.client_secret, + *args, **kwargs) # Initialize parent class + + @property + def domain(self): + return self._domain + + @domain.setter + def domain(self, val): + if not val: + raise ValueError( + "domain cannot be empty. you can also set AUTH0_DOMAIN value in .env file") + self._domain = val + + @property + def client_id(self): + return self._client_id + + @client_id.setter + def client_id(self, val): + if not val: + raise ValueError( + "client_id cannot be empty. you can also AUTH0_CLIENT_ID value in .env file") + self._client_id = val + + @property + def client_secret(self): + return self._client_secret + + @client_secret.setter + def client_secret(self, val): + if not val: + raise ValueError( + "client_secret cannot be empty. you can also AUTH0_CLIENT_SECRET value in .env file") + self._client_secret = val + + @property + def redirect_uri(self): + return self._redirect_uri + + @redirect_uri.setter + def redirect_uri(self, val): + if not val: + raise ValueError( + "redirect_uri cannot be empty. you can also AUTH0_REDIRECT_URI value in .env file") + self._redirect_uri = val + + # Initialize token verifier + jwk_url = f"https://{self.domain}/.well-known/jwks.json" + self.token_verifier = AsyncAsymmetricSignatureVerifier( + jwks_url=jwk_url) + + # Initialize FastAPI app + self.app = FastAPI() + # Temporary store for state values + self.state_store: Dict[str, Dict[bool, str]] = {} + # or secrets.token_urlsafe(32) # Secure random secret key + self.secret_key = os.environ.get("AUTH0_SECRET_KEY") + + # Register the callback route + @self.app.get("/auth/callback") + async def manage_callback(request: Request, response: Response): + """Parses and validates callback URL query parameters.""" + query_params = request.query_params + required_keys = {"code", "state"} + + if query_params.get("error"): + error_description = query_params.get( + "error_description", "Unknown error occurred.") + if query_params.get("state"): + del self.state_store[query_params.get("state")] + raise HTTPException(status_code=400, detail=error_description) + + if not required_keys.issubset(query_params.keys()): + raise HTTPException( + status_code=400, detail="Missing required query parameters.") + + received_state = query_params["state"] + + # Validate state to prevent CSRF attacks + if received_state not in self.state_store: + raise HTTPException( + status_code=400, detail="Invalid or missing state parameter.") + + # Extract code value from query string + received_code = query_params["code"] + + auth0_tokens = self._exchange_code_for_tokens(received_code) + + if auth0_tokens: + + cookie_data = await self._set_encrypted_session(auth0_tokens, state=received_state) + + response.set_cookie( + key="session", + value=cookie_data, + httponly=True, # Prevent JavaScript access + # secure=True, # Send only over HTTPS + samesite="Lax", # Protect against CSRF + # set expiry based on access token expiry + max_age=auth0_tokens["expires_in"], + ) + + # Remove state after validation (one-time use) + # del self.state_store[received_state] + + return {"message": "successul. you can now close this window"} + + @self.app.get("/auth/get_user") + async def get_user(request: Request): + """Reads the session cookie and extracts user info.""" + auth_cookie = request.cookies.get("session") + + if not auth_cookie: + raise HTTPException( + status_code=401, detail="Missing session cookie.") + + try: + # Decode the JWT stored in the session cookie + decoded_data = jwt.decode( + auth_cookie, self.secret_key, algorithms=["HS256"]) + + # Extract the user ID (sub) from the decoded JWT + user_id = decoded_data["user_id"] + + if not user_id: + raise HTTPException( + status_code=400, detail="Invalid session cookie: Missing 'sub' claim.") + + return (JSONResponse(content=decoded_data)) + + except jwt.ExpiredSignatureError: + raise HTTPException( + status_code=401, detail="Session cookie has expired.") + except jwt.InvalidTokenError: + raise HTTPException( + status_code=401, detail="Invalid session cookie.") + + # Start middleware server in a separate thread + self.host = urllib.parse.urlparse(self.redirect_uri).hostname + self.port = urllib.parse.urlparse(self.redirect_uri).port + self._start_server() + + def _start_server(self): + """Runs FastAPI as the middleware inside a separate thread.""" + server_thread = threading.Thread( + target=uvicorn.run, + args=(self.app,), + kwargs={"host": self.host, "port": self.port, "log_level": "info"}, + daemon=True # Daemon mode so it exits when the main thread exits + ) + server_thread.start() + + def _generate_state(self) -> str: + """Generate a secure random state and store it for validation.""" + state = secrets.token_urlsafe(16) # Generate a random state + self.state_store[state] = True # Store it temporarily + return state + + def _get_token_set(self, token_data: str, existing_refresh_token: str | None = None) -> dict: + + token_data = { + "access_token": token_data.get("access_token"), "expires_at": {"epoch": int(time.time())+token_data["expires_in"]}, + "refresh_token": token_data.get("refresh_token", existing_refresh_token), + "id_token": token_data.get("id_token"), + "scope": token_data.get("scope"), + } + return token_data + + def _get_linked_details(self, token_data: dict, existing_linked_connections: list[str] | None = None) -> list[str]: + """Extracts unique link_with values from authorization_details if type is 'account_linking' + and appends them to existing_linked_connections, avoiding duplicates. + """ + authz_details = token_data.get("authorization_details", []) + # Use a set to enforce uniqueness + linked_connections = set(existing_linked_connections or []) + + for item in authz_details: + if item.get("type") == "account_linking": + link_with = item.get("linkParams", {}).get("link_with") + if link_with: # Ensure link_with is not None + # Add to set to avoid duplicates + linked_connections.add(link_with) + + # Convert back to a list before returning + return list(linked_connections) + + async def _set_encrypted_session(self, token_data, state: str | None = None) -> str: + session_store = SessionStorage() + try: + decoded_id_token = await self.token_verifier.verify_signature(token_data["id_token"]) + user_id = decoded_id_token.get("sub") # Primary Key + if not user_id: + raise HTTPException( + status_code=400, detail="ID token missing 'sub' claim.") + except Exception as e: + raise HTTPException( + status_code=400, detail=f"Invalid ID token: {str(e)}") + + existing_encrypted_session = session_store._get_stored_session(user_id) + existing_linked_connections = {} + existing_refresh_token = {} + + if existing_encrypted_session: + # found existing session, check if there is a refresh token to keep + existing_session = self._get_encrypted_session(user_id) + existing_refresh_token = existing_session.get( + "tokens").get("refresh_token", None) + existing_linked_connections = existing_session.get( + "linked_connections") + + session_data = {} + session_data = {"user_id": user_id, + "tokens": self._get_token_set(token_data, existing_refresh_token), + "linked_connections": self._get_linked_details(token_data, existing_linked_connections) + } + + encrypted_session_data = jwt.encode( + session_data, self.secret_key, algorithm="HS256") + + # Stored in memory & auto-persisted + session_store._set_stored_session( + user_id=user_id, encrypted_session_data=encrypted_session_data) + + self.state_store[state] = {"user_id": user_id} + + # print("Session created/updated for:",user_id) + return encrypted_session_data + + def _get_encrypted_session(self, user_id): + session_store = SessionStorage() + encrypted_session = session_store._get_stored_session(user_id) + + if not encrypted_session: + return {"not found"} + + try: + # Decode the JWT stored in the session cookie + decoded_data = jwt.decode( + encrypted_session, self.secret_key, algorithms=["HS256"]) + + # Extract the user ID (sub) from the decoded JWT + user_id = decoded_data["user_id"] + + token_expiry = decoded_data.get("tokens", {}).get( + "expires_at", {}).get("epoch") + + if token_expiry > int(time.time()): + return decoded_data + else: + refresh_token = decoded_data.get( + "tokens", {}).get("refresh_token") + if refresh_token: + self._update_encrypted_session(user_id, refresh_token) + else: + session_store._delete_stored_session(user_id) + return {"session expired"} + + except jwt.ExpiredSignatureError: + return {"Session cookie has expired."} + except jwt.InvalidTokenError: + return {"Invalid session."} + + def _update_encrypted_session(self, user_id, refresh_token): + token_manager = GetToken() + updated_tokens = token_manager.refresh_token( + refresh_token=refresh_token) + + if updated_tokens: + self._set_encrypted_session(updated_tokens) + + def get_authorize_url( + self, + state: str, + connection: str | None = None, + scope: str | None = None, + additional_scopes: str | None = None, + **kwargs, + ) -> str: + + base_url = ( + f"https://{self.domain}/authorize?" + f"response_type=code&" + f"client_id={self.client_id}&" + f"redirect_uri={self.redirect_uri}&" + f"grant_type=authorization_code&" + f"state={state}&" + ) + + if connection is not None: + base_url += f"connection={connection}&" + + if scope is not None: + base_url += f"scope={scope}&" + + if additional_scopes is not None: + base_url += f"connection_scope={additional_scopes}&" + + # Add any additional custom arguments passed via kwargs + custom_args = "&".join( + [f"{key}={value}" for key, value in kwargs.items()]) + if custom_args: + base_url += custom_args + + return base_url + + def get_authorize_par_url( + self, + state: str, + request_uri: str, + ) -> str: + + base_url = ( + f"https://{self.domain}/authorize?" + f"client_id={self.client_id}&" + f"state={state}&" + f"request_uri={request_uri}" + ) + + return base_url + + def _exchange_code_for_tokens( + self, + code: str, + ) -> dict[str, Any]: + + get_token = GetToken(self.domain, self.client_id, self.client_secret) + + token_info = get_token.authorization_code( + code=code, + redirect_uri=self.redirect_uri, + grant_type="authorization_code" + ) + + return token_info + + def get_upstream_token( + self, + connection, + refresh_token: str, + additional_scopes: str | None = None, + ) -> dict[str, Any]: + + fcat = GetToken(self.domain, self.client_id, self.client_secret) + + x = fcat.federated_connection_access_token( + subject_token_type="urn:ietf:params:oauth:token-type:refresh_token", + subject_token=refresh_token, + requested_token_type="http://auth0.com/oauth/token-type/federated-connection-access-token", + connection=connection, + grant_type="urn:auth0:params:oauth:grant-type:token-exchange:federated-connection-access-token" + ) + + return (x) + + def get_session_details(self, user_id: str) -> dict[str, Any]: + session_store = SessionStorage() + if user_id in session_store._get_stored_sessions(): + return (self._get_encrypted_session(user_id)) + else: + return {"user_id not found in session store"} + + async def login(self, connection: str | None = None, scope: str | None = None, **kwargs) -> str: + + if scope is None: + scope = "openid profile email" + + state = self._generate_state() + + class LoginState: + def __init__(self, state_store, state): + self.state_store = state_store + self.state = state + + def is_completed(self) -> bool: + if self.state_store.get(self.state, False) == True: + return False + else: + return True + + def get_user(self) -> str: + return self.state_store.get(self.state, "login failed!") + + login_state = LoginState(self.state_store, state) + + auth_url = self.get_authorize_url( + state=state, connection=connection, scope=scope, **kwargs) + + # this initiates the login flow, and if successful the session is created + try: + webbrowser.open(auth_url) + except webbrowser.Error: + print("Please navigate here: {auth_url}") + + # check of compleition of login flow + while not login_state.is_completed(): + # not sure if this needed, but can save some unecessary polling time + time.sleep(0.05) + user_id = login_state.get_user() + + if user_id == "login failed!": + successul_login = False + else: + successul_login = True + + login_response = { + "is_successful": successul_login, + "user_id": user_id, + } + return login_response + + async def link(self, primary_user_id: str, connection: str, scope: str | None = None, **kwargs) -> str: + + state = self._generate_state() + + class LinkState: + def __init__(self, state_store, state): + self.state_store = state_store + self.state = state + + def is_completed(self) -> bool: + if self.state_store.get(self.state, False) == True: + return False + else: + return True + + def get_user(self) -> str: + return self.state_store.get(self.state, "login failed!").get("user_id") + + par_client = PushedAuthorizationRequests( + self.domain, self.client_id, self.client_secret) + + link_state = LinkState(self.state_store, state) + + y = par_client.pushed_authorization_request( + response_type="code", + redirect_uri=self.redirect_uri, + audience="https://accounts.auth101.dev/me/", + connection=connection, + state=state, + authorization_details=json.dumps([ + {"type": "account_linking", "linkParams": + {"primary_user_id": primary_user_id, + "link_with": connection, } + } + ]), + scope=scope, + prompt="login", + ) + + request_uri = y.get('request_uri') + + if request_uri: + auth_url = self.get_authorize_par_url( + state=state, request_uri=request_uri) + + # this initiates the login flow for the connection to link + try: + webbrowser.open(auth_url) + except webbrowser.Error: + print("Please navigate here: {auth_url}") + + # check of compleition of login flow + while not link_state.is_completed(): + # not sure if this needed, but can save some unecessary polling time + time.sleep(0.05) + user_id = link_state.get_user() + + if user_id == "login failed!": + successul_login = False + else: + successul_login = True + + link_response = { + "is_successful": successul_login, + "user_id": user_id, + } + return link_response + else: + return ("linking error") diff --git a/packages/auth0-ai/auth0_ai/session_storage.py b/packages/auth0-ai/auth0_ai/session_storage.py new file mode 100644 index 0000000..8b94c4f --- /dev/null +++ b/packages/auth0-ai/auth0_ai/session_storage.py @@ -0,0 +1,60 @@ +import shelve +from typing import Any + + +class SessionStorage: + + def __init__( + self, + use_local_cache: bool = True, + get_sessions=None, + get_session=None, + set_session=None, + delete_session=None, + ): + + self.get_sessions = get_sessions + self.get_session = get_session + self.set_session = set_session + self.delete_session = delete_session + + self.use_local_cache = use_local_cache or os.environ.get( + "AUTH0_USE_LOCAL_CACHE") + + @property + def use_local_cache(self): + return self._use_local_cache + + @use_local_cache.setter + def use_local_cache(self, val): + self._use_local_cache = val + + def _get_stored_sessions(self) -> Any: + if (self.use_local_cache): + with shelve.open(".sessions_cache") as sessions: + return list(sessions.keys()) + else: + return self.get_session() + + def _get_stored_session(self, user_id: str) -> str: + if (self.use_local_cache): + with shelve.open(".sessions_cache") as sessions: + return sessions.get(user_id) + else: + return self.get_session() + + def _set_stored_session(self, user_id, encrypted_session_data): + if (self.use_local_cache): + with shelve.open(".sessions_cache") as sessions: + sessions[user_id] = encrypted_session_data + sessions.sync() + else: + self.set_session() + + def _delete_stored_session(self, user_id): + if (self.use_local_cache): + with shelve.open(".sessions_cache") as sessions: + if user_id in sessions: + del sessions[user_id] + else: + self.del_session() diff --git a/packages/auth0-ai/pyproject.toml b/packages/auth0-ai/pyproject.toml new file mode 100644 index 0000000..b7c4b17 --- /dev/null +++ b/packages/auth0-ai/pyproject.toml @@ -0,0 +1,25 @@ +[tool.poetry] +name = "langchain-auth0-ai" +version = "0.1.0" +description = "This package integrates LangChain with Auth0 AI for enhanced document retrieval capabilities." +license = "apache-2.0" +homepage = "https://auth0.com" +authors = [ + "Jose F. Romaniello ", + "Javier Centurion ", +] +readme = "README.md" + +[tool.poetry.dependencies] +python = "^3.11" +openfga-sdk = "^0.9.0" +langchain = "^0.3.11" + +[tool.poetry.group.test.dependencies] +pytest-randomly = "^3.15.0" +pytest-asyncio = "^0.25.0" +pytest = "^8.2.0" + +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api" From 3b5c62fabfc4580f734b8a29cd1f994a93ddc3da Mon Sep 17 00:00:00 2001 From: mustafadeel Date: Tue, 11 Feb 2025 21:54:23 -0500 Subject: [PATCH 02/11] fixing package name --- packages/auth0-ai/README.md | 3 +-- packages/auth0-ai/pyproject.toml | 12 ++++++------ 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/packages/auth0-ai/README.md b/packages/auth0-ai/README.md index c15a0ec..527daee 100644 --- a/packages/auth0-ai/README.md +++ b/packages/auth0-ai/README.md @@ -5,8 +5,7 @@ This package provides base methods to use Auth0 with your AI use cases. ## Installation ```bash -# pip install langchain-auth0-ai -pip install git+https://github.com/atko-cic-lab/auth0-ai-python.git@main#subdirectory=packages/auth0-ai +pip install git+https://github.com/mustafadeel/auth0-ai-python.git@main#subdirectory=packages/auth0-ai ``` ## Running Tests diff --git a/packages/auth0-ai/pyproject.toml b/packages/auth0-ai/pyproject.toml index b7c4b17..fdb3ad5 100644 --- a/packages/auth0-ai/pyproject.toml +++ b/packages/auth0-ai/pyproject.toml @@ -1,19 +1,19 @@ [tool.poetry] -name = "langchain-auth0-ai" +name = "auth0-ai" version = "0.1.0" -description = "This package integrates LangChain with Auth0 AI for enhanced document retrieval capabilities." +description = "This package provides base auth capability for Auth0 AI." license = "apache-2.0" homepage = "https://auth0.com" authors = [ - "Jose F. Romaniello ", - "Javier Centurion ", + "Adeel Mustafa ", ] readme = "README.md" [tool.poetry.dependencies] python = "^3.11" -openfga-sdk = "^0.9.0" -langchain = "^0.3.11" +auth0_python = "^4.8.0" +fastapi[standard] = "^0.115.0" +typing [tool.poetry.group.test.dependencies] pytest-randomly = "^3.15.0" From 98ad2a662291fd07f6704de9abd7008392ac8080 Mon Sep 17 00:00:00 2001 From: mustafadeel Date: Tue, 11 Feb 2025 21:58:37 -0500 Subject: [PATCH 03/11] version fixes --- packages/auth0-ai/pyproject.toml | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/packages/auth0-ai/pyproject.toml b/packages/auth0-ai/pyproject.toml index fdb3ad5..4240e48 100644 --- a/packages/auth0-ai/pyproject.toml +++ b/packages/auth0-ai/pyproject.toml @@ -10,15 +10,14 @@ authors = [ readme = "README.md" [tool.poetry.dependencies] -python = "^3.11" -auth0_python = "^4.8.0" -fastapi[standard] = "^0.115.0" -typing +python = ">=3.6" +auth0_python = ">=4.8.0" +fastapi[standard] = ">=0.115.0" [tool.poetry.group.test.dependencies] -pytest-randomly = "^3.15.0" -pytest-asyncio = "^0.25.0" -pytest = "^8.2.0" +pytest-randomly = ">=3.15.0" +pytest-asyncio = ">=0.25.0" +pytest = ">=8.2.0" [build-system] requires = ["poetry-core"] From 69abe2d1deea060c81dd1ceab03d176eb8c666e3 Mon Sep 17 00:00:00 2001 From: mustafadeel Date: Thu, 13 Feb 2025 23:01:06 -0500 Subject: [PATCH 04/11] added a User class and updated linking --- packages/auth0-ai/auth0_ai/__init__.py | 5 +- packages/auth0-ai/auth0_ai/ai_auth.py | 278 +++++++++++++++++-------- packages/auth0-ai/pyproject.toml | 15 +- 3 files changed, 205 insertions(+), 93 deletions(-) diff --git a/packages/auth0-ai/auth0_ai/__init__.py b/packages/auth0-ai/auth0_ai/__init__.py index e953d42..874d9eb 100644 --- a/packages/auth0-ai/auth0_ai/__init__.py +++ b/packages/auth0-ai/auth0_ai/__init__.py @@ -1,4 +1,3 @@ -from .ai_auth import AIAuth - -__all__ = ["AIAuth"] +from .ai_auth import AIAuth, User +__all__ = ["AIAuth", "User"] diff --git a/packages/auth0-ai/auth0_ai/ai_auth.py b/packages/auth0-ai/auth0_ai/ai_auth.py index 7557696..31b3565 100644 --- a/packages/auth0-ai/auth0_ai/ai_auth.py +++ b/packages/auth0-ai/auth0_ai/ai_auth.py @@ -2,7 +2,12 @@ from typing import Any -from auth0_python import AuthenticationBase, GetToken, AsyncAsymmetricSignatureVerifier, PushedAuthorizationRequests +from auth0.authentication.base import AuthenticationBase +from auth0.authentication import GetToken +from auth0.authentication.async_token_verifier import AsyncAsymmetricSignatureVerifier +from auth0.authentication.pushed_authorization_requests import PushedAuthorizationRequests + +from .session_storage import SessionStorage import webbrowser import urllib.parse @@ -11,7 +16,7 @@ import jwt # PyJWT for signing cookies from fastapi import FastAPI, Request, HTTPException, Response -from fastapi.responses import JSONResponse +from fastapi.responses import JSONResponse, RedirectResponse from typing import Any, Dict @@ -23,9 +28,6 @@ import json -from .session_storage import SessionStorage - - class AIAuth(AuthenticationBase): def __init__( @@ -106,6 +108,9 @@ def redirect_uri(self, val): # or secrets.token_urlsafe(32) # Secure random secret key self.secret_key = os.environ.get("AUTH0_SECRET_KEY") + # Initialize the SessionStore + self.session_store = SessionStorage() + # Register the callback route @self.app.get("/auth/callback") async def manage_callback(request: Request, response: Response): @@ -150,11 +155,29 @@ async def manage_callback(request: Request, response: Response): max_age=auth0_tokens["expires_in"], ) - # Remove state after validation (one-time use) - # del self.state_store[received_state] + user_id = self.state_store[received_state].get( + "user_id", "failed") + self.state_store[received_state] = { + "user_id": user_id, "is_completed": True} return {"message": "successul. you can now close this window"} + # Register the callback route + + # Register the login route + @self.app.get("/auth/login") + async def manage_login(request: Request, response: Response): + # if scope is None: + scope = "openid profile email" + connection = "Username-Password-Authentication" + + state = self._generate_state() + + auth_url = self.get_authorize_url( + state=state, connection=connection, scope=scope, **kwargs) + + return RedirectResponse(url=auth_url, status_code=302) + @self.app.get("/auth/get_user") async def get_user(request: Request): """Reads the session cookie and extracts user info.""" @@ -203,7 +226,8 @@ def _start_server(self): def _generate_state(self) -> str: """Generate a secure random state and store it for validation.""" state = secrets.token_urlsafe(16) # Generate a random state - self.state_store[state] = True # Store it temporarily + # Store it temporarily and flag it as false as we havent received it back as yet + self.state_store[state] = {"is_competed": False} return state def _get_token_set(self, token_data: str, existing_refresh_token: str | None = None) -> dict: @@ -235,18 +259,24 @@ def _get_linked_details(self, token_data: dict, existing_linked_connections: lis return list(linked_connections) async def _set_encrypted_session(self, token_data, state: str | None = None) -> str: - session_store = SessionStorage() - try: - decoded_id_token = await self.token_verifier.verify_signature(token_data["id_token"]) - user_id = decoded_id_token.get("sub") # Primary Key - if not user_id: + id_token = token_data.get("id_token", "") + if id_token: + try: + decoded_id_token = await self.token_verifier.verify_signature(id_token) + user_id = decoded_id_token.get("sub") # Primary Key + if not user_id: + raise HTTPException( + status_code=400, detail="ID token missing 'sub' claim.") + except Exception as e: raise HTTPException( - status_code=400, detail="ID token missing 'sub' claim.") - except Exception as e: - raise HTTPException( - status_code=400, detail=f"Invalid ID token: {str(e)}") + status_code=400, detail=f"Invalid ID token: {str(e)}") + else: + # this can happen in the case of the linking flow if the openid scope is not requested + # and no ID token is issued + user_id = self.state_store[state].get("user_id") - existing_encrypted_session = session_store._get_stored_session(user_id) + existing_encrypted_session = self.session_store._get_stored_session( + user_id) existing_linked_connections = {} existing_refresh_token = {} @@ -268,7 +298,7 @@ async def _set_encrypted_session(self, token_data, state: str | None = None) -> session_data, self.secret_key, algorithm="HS256") # Stored in memory & auto-persisted - session_store._set_stored_session( + self.session_store._set_stored_session( user_id=user_id, encrypted_session_data=encrypted_session_data) self.state_store[state] = {"user_id": user_id} @@ -277,8 +307,7 @@ async def _set_encrypted_session(self, token_data, state: str | None = None) -> return encrypted_session_data def _get_encrypted_session(self, user_id): - session_store = SessionStorage() - encrypted_session = session_store._get_stored_session(user_id) + encrypted_session = self.session_store._get_stored_session(user_id) if not encrypted_session: return {"not found"} @@ -302,7 +331,7 @@ def _get_encrypted_session(self, user_id): if refresh_token: self._update_encrypted_session(user_id, refresh_token) else: - session_store._delete_stored_session(user_id) + self.session_store._delete_stored_session(user_id) return {"session expired"} except jwt.ExpiredSignatureError: @@ -403,13 +432,18 @@ def get_upstream_token( return (x) def get_session_details(self, user_id: str) -> dict[str, Any]: - session_store = SessionStorage() - if user_id in session_store._get_stored_sessions(): + if user_id in self.session_store._get_stored_sessions(): return (self._get_encrypted_session(user_id)) else: return {"user_id not found in session store"} - async def login(self, connection: str | None = None, scope: str | None = None, **kwargs) -> str: + def get_session(self, user: User) -> dict[str, Any]: + if User.user_id in self.session_store._get_stored_sessions(): + return (self._get_encrypted_session(User.user_id)) + else: + return {"user_id not found in session store"} + + async def interactive_login(self, connection: str | None = None, scope: str | None = None, **kwargs) -> User: if scope is None: scope = "openid profile email" @@ -420,16 +454,17 @@ class LoginState: def __init__(self, state_store, state): self.state_store = state_store self.state = state + self.start_time = time.time() def is_completed(self) -> bool: - if self.state_store.get(self.state, False) == True: - return False - else: - return True + return self.state_store[self.state].get("is_completed") def get_user(self) -> str: return self.state_store.get(self.state, "login failed!") + def terminate(self): + del self.state_store[self.state] + login_state = LoginState(self.state_store, state) auth_url = self.get_authorize_url( @@ -444,21 +479,24 @@ def get_user(self) -> str: # check of compleition of login flow while not login_state.is_completed(): # not sure if this needed, but can save some unecessary polling time - time.sleep(0.05) + if (time.time() < login_state.start_time + 60): + time.sleep(0.25) + else: + # login has timed out, we can clean up state + login_state.terminate() + return ("login timeout") + user_id = login_state.get_user() + # no longer need to retain the state anymore, we can clean it up + login_state.terminate() + if user_id == "login failed!": - successul_login = False + return "login failed" else: - successul_login = True + return User(self, user_id=user_id.get("user_id")) - login_response = { - "is_successful": successul_login, - "user_id": user_id, - } - return login_response - - async def link(self, primary_user_id: str, connection: str, scope: str | None = None, **kwargs) -> str: + async def link(self, primary_user_id: str, connection: str, id_token: str, scope: str | None = None, **kwargs) -> str: state = self._generate_state() @@ -466,64 +504,138 @@ class LinkState: def __init__(self, state_store, state): self.state_store = state_store self.state = state + self.start_time = time.time() def is_completed(self) -> bool: - if self.state_store.get(self.state, False) == True: - return False - else: - return True + return self.state_store[self.state].get("is_completed", False) def get_user(self) -> str: return self.state_store.get(self.state, "login failed!").get("user_id") + def set_user(self, user_id: str) -> None: + is_comeplted = self.state_store[self.state].get( + "is_competed", False) + self.state_store[self.state] = { + "is_completed": is_comeplted, "user_id": user_id} + + def terminate(self): + del self.state_store[self.state] + par_client = PushedAuthorizationRequests( self.domain, self.client_id, self.client_secret) link_state = LinkState(self.state_store, state) + link_state.set_user(primary_user_id) - y = par_client.pushed_authorization_request( - response_type="code", - redirect_uri=self.redirect_uri, - audience="https://accounts.auth101.dev/me/", - connection=connection, - state=state, - authorization_details=json.dumps([ - {"type": "account_linking", "linkParams": - {"primary_user_id": primary_user_id, - "link_with": connection, } - } - ]), - scope=scope, - prompt="login", - ) - - request_uri = y.get('request_uri') - - if request_uri: - auth_url = self.get_authorize_par_url( - state=state, request_uri=request_uri) + try: + y = par_client.pushed_authorization_request( + response_type="code", + nonce="mynonce", + redirect_uri=self.redirect_uri, + audience="my-account", + state=state, + authorization_details=json.dumps([ + {"type": "link_account", "requested_connection": connection} + ]), + scope="openid profile", + id_token_hint=id_token, + ) + + request_uri = y.get('request_uri') + + if request_uri: + auth_url = self.get_authorize_par_url( + state=state, request_uri=request_uri) + + # this initiates the login flow for the connection to link + try: + webbrowser.open(auth_url) + except webbrowser.Error: + print("Please navigate here:", auth_url) + + # check of compleition of login flow + while not link_state.is_completed(): + # not sure if this needed, but can save some unecessary polling time + if (time.time() < link_state.start_time + 60): + time.sleep(0.05) + else: + link_state.terminate() + return ("linking timeout") + + user_id = link_state.get_user() + + # no longer need state, we can clean it up + link_state.terminate() + + if user_id == "login failed!": + successul_login = False + else: + successul_login = True - # this initiates the login flow for the connection to link - try: - webbrowser.open(auth_url) - except webbrowser.Error: - print("Please navigate here: {auth_url}") - - # check of compleition of login flow - while not link_state.is_completed(): - # not sure if this needed, but can save some unecessary polling time - time.sleep(0.05) - user_id = link_state.get_user() - - if user_id == "login failed!": - successul_login = False + link_response = { + "is_successful": successul_login, + "user_id": user_id, + } else: - successul_login = True + link_response = { + "is_successful": False, + "user_id": user_id, + } + except Exception as error: + print(error) link_response = { - "is_successful": successul_login, - "user_id": user_id, + "is_successful": False, + "user_id": primary_user_id, } - return link_response + + return link_response + + +class User(AIAuth): + + def __init__(self, parent, user_id: str): + self.parent = parent + self.user_id = user_id + self.client = parent.client + + async def link(self, connection: str | None = None, scope: str | None = None, **kwargs) -> str: + return await self.parent.link(connection=connection, primary_user_id=self.user_id, id_token=self.get_id_token(), scope=scope, **kwargs) + + def get_id_token(self) -> str: + if self.user_id in self.parent.session_store._get_stored_sessions(): + return (self.parent._get_encrypted_session(self.user_id).get("tokens").get("id_token")) + else: + return {"user_id not found in session store"} + + def get_access_token(self) -> str: + if self.user_id in self.parent.session_store._get_stored_sessions(): + return (self.parent._get_encrypted_session(self.user_id).get("tokens").get("access_token")) + else: + return {"user_id not found in session store"} + + def get_refresh_token(self) -> str: + if self.user_id in self.parent.session_store._get_stored_sessions(): + return (self.parent._get_encrypted_session(self.user_id).get("tokens").get("refresh_token")) else: - return ("linking error") + return {"user_id not found in session store"} + + def get_3rd_party_token(self, connection: str) -> dict[str, Any]: + return self.parent.get_upstream_token(connection, self.get_refresh_token()) + + def tokeninfo(self) -> dict[str, Any]: + id_token = self.get_id_token() + self.parent.post() + data: dict[str, Any] = self.parent.get( + url=f"https://{self.parent.domain}/tokeninfo", + headers={"Authorization": f"Bearer {access_token}"}, + ) + return data + + def userinfo(self, access_token: str | None = None) -> dict[str, Any]: + access_token = access_token or self.get_access_token() + data: dict[str, Any] = self.parent.get( + url=f"https://{self.parent.domain}/userinfo", + headers={"Authorization": f"Bearer {access_token}"}, + ) + return data diff --git a/packages/auth0-ai/pyproject.toml b/packages/auth0-ai/pyproject.toml index 4240e48..aae17d1 100644 --- a/packages/auth0-ai/pyproject.toml +++ b/packages/auth0-ai/pyproject.toml @@ -8,17 +8,18 @@ authors = [ "Adeel Mustafa ", ] readme = "README.md" +packages = [{ include = "auth0_ai" }] [tool.poetry.dependencies] -python = ">=3.6" -auth0_python = ">=4.8.0" -fastapi[standard] = ">=0.115.0" +python = "^3.6" +auth0_python = "^4.8.0" +fastapi = {version = "^0.115.0", extras = ["standard"]} [tool.poetry.group.test.dependencies] -pytest-randomly = ">=3.15.0" -pytest-asyncio = ">=0.25.0" -pytest = ">=8.2.0" +pytest-randomly = "^3.15.0" +pytest-asyncio = "^0.25.0" +pytest = "^8.2.0" [build-system] requires = ["poetry-core"] -build-backend = "poetry.core.masonry.api" +build-backend = "poetry.core.masonry.api" \ No newline at end of file From 90f405dbae1dc648eb580e1863567ade408d6476 Mon Sep 17 00:00:00 2001 From: mustafadeel Date: Thu, 13 Feb 2025 23:07:35 -0500 Subject: [PATCH 05/11] Updated README readme typo typo in readme --- packages/auth0-ai/README.md | 40 ++++++++++++++++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/packages/auth0-ai/README.md b/packages/auth0-ai/README.md index 527daee..0542163 100644 --- a/packages/auth0-ai/README.md +++ b/packages/auth0-ai/README.md @@ -26,8 +26,46 @@ pip install git+https://github.com/mustafadeel/auth0-ai-python.git@main#subdirec ## Usage +Create a .env file with the following deatils: + +``` +AUTH0_DOMAIN='<>' +AUTH0_CLIENT_ID='<>' +AUTH0_CLIENT_SECRET='<>' +AUTH0_REDIRECT_URI='<>' +AUTH0_SECRET_KEY='ALongRandomlyGeneratedString' +``` + +Create a python script for an interactive login, link and tool token example: + ```python -tbd +from dotenv import find_dotenv, load_dotenv + +import asyncio + +from auth0_ai import AIAuth, User + +ENV_FILE = find_dotenv() +if ENV_FILE: + load_dotenv(ENV_FILE) + +auth_client = AIAuth() + +async def login(): + return await auth_client.interactive_login(connection="Username-Password-Authentication", scope="openid email offline_access") + +async def link(user_id, connection): + linked = await auth_client.link(primary_user_id=user_id, connection=connection, scope="openid email") + return linked + +user1 = asyncio.run(login()) + +print("-" * 20) +print("USER DETAILS:", auth_client.get_session_details(user1)) + +link_status = asyncio.run(user1.link(connection="github")) + +github_token = user1.get_3rd_party_token("github") ``` --- From a44bbc1351f27286459ddc2e40c34d50173c50ad Mon Sep 17 00:00:00 2001 From: mustafadeel Date: Sat, 15 Feb 2025 15:53:03 -0500 Subject: [PATCH 06/11] fixed a typo in get_session using user instead of the Class User --- packages/auth0-ai/auth0_ai/ai_auth.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/auth0-ai/auth0_ai/ai_auth.py b/packages/auth0-ai/auth0_ai/ai_auth.py index 31b3565..7f51b7d 100644 --- a/packages/auth0-ai/auth0_ai/ai_auth.py +++ b/packages/auth0-ai/auth0_ai/ai_auth.py @@ -438,8 +438,8 @@ def get_session_details(self, user_id: str) -> dict[str, Any]: return {"user_id not found in session store"} def get_session(self, user: User) -> dict[str, Any]: - if User.user_id in self.session_store._get_stored_sessions(): - return (self._get_encrypted_session(User.user_id)) + if user.user_id in self.session_store._get_stored_sessions(): + return (self._get_encrypted_session(user.user_id)) else: return {"user_id not found in session store"} From 86c1bd92f30cea856244fb70bf1a1885746bffbb Mon Sep 17 00:00:00 2001 From: mustafadeel Date: Sat, 15 Feb 2025 15:55:20 -0500 Subject: [PATCH 07/11] updated readme fixed the call for get session to use implied user --- packages/auth0-ai/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/auth0-ai/README.md b/packages/auth0-ai/README.md index 0542163..4d221c4 100644 --- a/packages/auth0-ai/README.md +++ b/packages/auth0-ai/README.md @@ -61,7 +61,7 @@ async def link(user_id, connection): user1 = asyncio.run(login()) print("-" * 20) -print("USER DETAILS:", auth_client.get_session_details(user1)) +print("USER DETAILS:", auth_client.get_session(user1)) link_status = asyncio.run(user1.link(connection="github")) From c7540e6d3f5299924c86b4998e44ce25bab60083 Mon Sep 17 00:00:00 2001 From: mustafadeel Date: Mon, 17 Feb 2025 18:46:26 -0500 Subject: [PATCH 08/11] added logout and refactored cookie mgmt --- localhost.crt | 19 +++ localhost.key | 28 +++++ packages/auth0-ai/auth0_ai/ai_auth.py | 169 +++++++++++++++++++------- 3 files changed, 175 insertions(+), 41 deletions(-) create mode 100644 localhost.crt create mode 100644 localhost.key diff --git a/localhost.crt b/localhost.crt new file mode 100644 index 0000000..2705484 --- /dev/null +++ b/localhost.crt @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDDzCCAfegAwIBAgIUeheYrRpad428Bo7n3O4iSeDvl/wwDQYJKoZIhvcNAQEL +BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTI1MDIxNzE3MzExOFoXDTI1MDMx +OTE3MzExOFowFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF +AAOCAQ8AMIIBCgKCAQEA14LBVxD5j3Rs0lqANYUoNHYPBGII1Z9Lh1SjEyapI857 +JUYZm9VwVLqZJmcv8ppQeyE6nQOI+EvmMC7zy2w0us4hmayH93XEjyBgFv9KNifG +csLH9FDwJmBK1pyIoSgWE82GjuDJOA/ociuFoi5TjNgL1F4DpS7zOMA+OJcOW9B3 +cWMT4F7oGLfnhrd+J1k/vsDGU047hivspFXKNBA1HpVyHC/nAIKUNZZkCaR0NjCI +zSGxQ19f2uG59J1jizlruk3y7UlvFskSGUocvRiNzPgsMpWplxwszqv0eEYV/Ber +HjaLnanEcF7zGEuGc2oMWziTmLPBQF4nKOlJGJsACwIDAQABo1kwVzAUBgNVHREE +DTALgglsb2NhbGhvc3QwCwYDVR0PBAQDAgeAMBMGA1UdJQQMMAoGCCsGAQUFBwMB +MB0GA1UdDgQWBBRI2ZH2Clz43NSxOxxXxm7tBaUGdzANBgkqhkiG9w0BAQsFAAOC +AQEAryDWVc6VXTn4rV/BXdxLMj//0aqpzmjY8OE+LElqEijvUPlM5RJCteK/bi87 +C8t02yDmaWnRuoblguusJayVDBSh3bxG6QjT+ijjTgOVDs30aOvTp3mxb34QYhox ++Mwg9IDrmqoSsXd9uNToBV4639IRmGSE/4yz/OTq+F1tlGpi8REhxIbmyN3ODExK +Uyp1wALTAR4koErqzruQ7vh7oWMNU3Mb6w6SizCe9K2oqMCQ8zCHSHKnP/eQn4Bo +6dzvX7HrTsJfbQKosMrDJQhi6QXZBE7DRzbY4vYpQ8E3HqvoAv9jBd1KrkXPMpN0 +ORR43t6HviCz9puPh0kMkdGRIg== +-----END CERTIFICATE----- diff --git a/localhost.key b/localhost.key new file mode 100644 index 0000000..3f70b28 --- /dev/null +++ b/localhost.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDXgsFXEPmPdGzS +WoA1hSg0dg8EYgjVn0uHVKMTJqkjznslRhmb1XBUupkmZy/ymlB7ITqdA4j4S+Yw +LvPLbDS6ziGZrIf3dcSPIGAW/0o2J8Zywsf0UPAmYErWnIihKBYTzYaO4Mk4D+hy +K4WiLlOM2AvUXgOlLvM4wD44lw5b0HdxYxPgXugYt+eGt34nWT++wMZTTjuGK+yk +Vco0EDUelXIcL+cAgpQ1lmQJpHQ2MIjNIbFDX1/a4bn0nWOLOWu6TfLtSW8WyRIZ +Shy9GI3M+CwylamXHCzOq/R4RhX8F6seNoudqcRwXvMYS4ZzagxbOJOYs8FAXico +6UkYmwALAgMBAAECggEACf+SebgbY+jqXY3+Ub1WQqzRgIoNz99ekS4/jJVoFnWv +Z+jLKlwqJHwtu8bgxhgbsMK3Ze5yjdZznPuoquDfx2Tl0SvceQIZNuyxGJAKgN2y +isN3pGGW6qjf//nuKs/hylRoMDvEihnO1nEnd4E/thKV5engsGqvtQvSNyzm6SJ6 +d/u5uFK7VQrWAkJ1+xCSTPPWi77CiMBh/qMY2Y6Ph0wYexQdNmS8np1cN7c5+N1b +evLSQAIiNUhXmVYY31z6oWHq+CYM/CRyWgDfQHFlUoBisIizzMptvDcvFi+7Rp22 +qp1Lm6KNOPG7/nq/zDNSBfd1FqlaR/TcQDMtVUKaCQKBgQD7S+/NovCBnkhiEciU +GHGIvpfoyXO/UfiGKcf4WZCKfqlz9RNvzmuekRK1GAJzRnsO6HBeNwPasTHFIBJd +WZIJ4nJS681R/TTLfIPveKuo99F1iB2mZHDX9dkQS3ZlK7OuudVuE0xTcUoNAD0P +9zWcSoBrLrEToHYldtbq+FW8swKBgQDbi1qhNduMV/wq5SuEAntSraKdKcJho47R +MkNu3grg5sNNaMbQJhiO5KHBdfPefgQAqp0ZuYSMPg8GPnIrKN4nFegkBUWgjoi9 +zAgfmdSaExo8gNe3S5J9wSEVpskVnDtLjS0FYcTOCLnNneiim8TGLKHXUpvWOTdt +wcfPCgGLSQKBgBFlk2dgBVhj1cz8QC+IdauqziductXm3daj49UcljYQSLjfWYYe ++zJSBsKEs/64/WHt04GiO2ETbUehTcQqpEKM668z5dXsOpBvwU59wxyCc3y4fJz9 +TRaWTX2kS8D7QogxE0Z4jYslR6QYxSFq0spMGhHRfK7IKAW18XD42i6jAoGAMAZL +zPf7DrgwcTGwUzA3yd4xtC9uVe1xUFGubpIjzw6rqkNBOkcbGCbrO2aR8hmexoaL +1xS96e+pWbRPRSGrduFT5o1Ard6ACwSWwlLkLs/+7T1B8taVNO0KT7IsSo3iaqR3 +NLYuVuORwWjJesiYQsGApZlsfXAGr/uzuZZ2wAECgYEAgB9QWsVzWHhbJ1YNK9qE +igWV4src1lRcwq3pN5iYLT6NXMvLZjPHOtEz5PRO4M85SwDlOG+82BWmDMNgFwoq +W+BC+xjRbqywT/7Acj0zS1GUm2gPqPt606b/H8L4WjENJO23C1iEF/brOb6GJKxy +WBfnsVRqocHJ/hHuLIdBvMQ= +-----END PRIVATE KEY----- diff --git a/packages/auth0-ai/auth0_ai/ai_auth.py b/packages/auth0-ai/auth0_ai/ai_auth.py index 7f51b7d..ec0f860 100644 --- a/packages/auth0-ai/auth0_ai/ai_auth.py +++ b/packages/auth0-ai/auth0_ai/ai_auth.py @@ -4,6 +4,7 @@ from auth0.authentication.base import AuthenticationBase from auth0.authentication import GetToken +from auth0.authentication import RevokeToken from auth0.authentication.async_token_verifier import AsyncAsymmetricSignatureVerifier from auth0.authentication.pushed_authorization_requests import PushedAuthorizationRequests @@ -143,13 +144,14 @@ async def manage_callback(request: Request, response: Response): if auth0_tokens: - cookie_data = await self._set_encrypted_session(auth0_tokens, state=received_state) + cookie_session_data = await self._set_encrypted_session(auth0_tokens, state=received_state) response.set_cookie( - key="session", - value=cookie_data, + key="sessionData", + value=cookie_session_data, + path="/auth", httponly=True, # Prevent JavaScript access - # secure=True, # Send only over HTTPS + secure=True, # Send only over HTTPS samesite="Lax", # Protect against CSRF # set expiry based on access token expiry max_age=auth0_tokens["expires_in"], @@ -167,21 +169,33 @@ async def manage_callback(request: Request, response: Response): # Register the login route @self.app.get("/auth/login") async def manage_login(request: Request, response: Response): - # if scope is None: - scope = "openid profile email" - connection = "Username-Password-Authentication" - state = self._generate_state() + # check cookie for existing session + auth_cookie = request.cookies.get("sessionData") + if auth_cookie: + decoded_data = jwt.decode( + auth_cookie, self.secret_key, algorithms=["HS256"]) + # Session cookie exists, do something with it + # ... + return {"session": decoded_data} + else: + # No session cookie, redirect to Auth0 + query_params = urllib.parse.parse_qs(request.query_params) + scope = query_params.get("scope", ["openid profile email"])[0] + connection = query_params.get("connection", ["Username-Password-Authentication"])[0] + return_to = query_params.get("return_to", ["/"])[0] - auth_url = self.get_authorize_url( - state=state, connection=connection, scope=scope, **kwargs) + state = self._generate_state() - return RedirectResponse(url=auth_url, status_code=302) + auth_url = self.get_authorize_url( + state=state, connection=connection, scope=scope, redirect_uri=return_to) + + return RedirectResponse(url=auth_url, status_code=302) @self.app.get("/auth/get_user") async def get_user(request: Request): """Reads the session cookie and extracts user info.""" - auth_cookie = request.cookies.get("session") + auth_cookie = request.cookies.get("sessionData") if not auth_cookie: raise HTTPException( @@ -193,7 +207,7 @@ async def get_user(request: Request): auth_cookie, self.secret_key, algorithms=["HS256"]) # Extract the user ID (sub) from the decoded JWT - user_id = decoded_data["user_id"] + user_id = decoded_data.get('user').get('sub') if not user_id: raise HTTPException( @@ -208,20 +222,95 @@ async def get_user(request: Request): raise HTTPException( status_code=401, detail="Invalid session cookie.") + @self.app.get("/auth/logout") + async def manage_logout(request: Request, response: Response): + """Reads the session cookie and extracts user info.""" + auth_cookie = request.cookies.get("sessionData") + + if not auth_cookie: + raise HTTPException( + status_code=401, detail="Missing session cookie.") + + try: + # Decode the JWT stored in the session cookie + decoded_data = jwt.decode( + auth_cookie, self.secret_key, algorithms=["HS256"]) + + # Extract the user ID (sub) from the decoded JWT + user_id = decoded_data.get('user').get('sub') + + if not user_id: + raise HTTPException( + status_code=400, detail="Invalid session cookie: Missing 'sub' claim.") + + self.get(url=f"https://{self.domain}/v2/logout") + revoke_rt = RevokeToken(self.domain, self.client_id, self.client_secret) + revoke_rt.revoke_refresh_token(token=decoded_data.get("tokens").get("refresh_token")) + + response.delete_cookie(key="sessionData",path="/auth") + self.session_store._delete_stored_session(user_id) + + # MODIFY RESPONSE to ensure it returns properly + response.body = b'{"message": "logout successful"}' + response.status_code = 200 + response.media_type = "application/json" + + return response + + except jwt.ExpiredSignatureError: + raise HTTPException( + status_code=401, detail="Session cookie has expired.") + except jwt.InvalidTokenError: + raise HTTPException( + status_code=401, detail="Invalid session cookie.") + + # Start middleware server in a separate thread self.host = urllib.parse.urlparse(self.redirect_uri).hostname self.port = urllib.parse.urlparse(self.redirect_uri).port + self.protocol = urllib.parse.urlparse(self.redirect_uri).scheme self._start_server() + def _is_valid_file(self,file_path): + """Check if the file exists and is accessible.""" + return os.path.isfile(file_path) and os.access(file_path, os.R_OK) + def _start_server(self): """Runs FastAPI as the middleware inside a separate thread.""" - server_thread = threading.Thread( - target=uvicorn.run, - args=(self.app,), - kwargs={"host": self.host, "port": self.port, "log_level": "info"}, - daemon=True # Daemon mode so it exits when the main thread exits - ) - server_thread.start() + if (self.protocol == "https"): + + ssl_keyfile = os.getenv("AUTH0_SSL_KEYFILE") + ssl_certfile = os.getenv("AUTH0_SSL_CERTFILE") + + if not self._is_valid_file(ssl_keyfile) or not self._is_valid_file(ssl_certfile): + raise ValueError( + "AUTH0_SSL_KEYFILE and AUTH0_SSL_CERTFILE environment variables must be set with valid file paths for HTTPS.") + + server_thread = threading.Thread( + target=uvicorn.run, + args=(self.app,), + kwargs={ + "host": self.host, + "port": self.port, + "ssl_keyfile": ssl_keyfile, # Path to private key + "ssl_certfile": ssl_certfile, # Path to certificate + "log_level": "error"}, + daemon=True # Daemon mode so it exits when the main thread exits + ) + else: + server_thread = threading.Thread( + target=uvicorn.run, + args=(self.app,), + kwargs={ + "host": self.host, + "port": self.port, + "log_level": "info"}, + daemon=True # Daemon mode so it exits when the main thread exits + ) + try: + server_thread.start() + except Exception as e: + print(f"Error starting middleware server: {str(e)}") def _generate_state(self) -> str: """Generate a secure random state and store it for validation.""" @@ -229,15 +318,16 @@ def _generate_state(self) -> str: # Store it temporarily and flag it as false as we havent received it back as yet self.state_store[state] = {"is_competed": False} return state - + def _get_token_set(self, token_data: str, existing_refresh_token: str | None = None) -> dict: - + """Extracts the access token, scope, refresh token, and expiry time from the token_data.""" token_data = { - "access_token": token_data.get("access_token"), "expires_at": {"epoch": int(time.time())+token_data["expires_in"]}, - "refresh_token": token_data.get("refresh_token", existing_refresh_token), - "id_token": token_data.get("id_token"), + "access_token": token_data.get("access_token"), "scope": token_data.get("scope"), + "expires_at": {"epoch": int(time.time())+token_data["expires_in"]}, + "refresh_token": token_data.get("refresh_token", existing_refresh_token), } + return token_data def _get_linked_details(self, token_data: dict, existing_linked_connections: list[str] | None = None) -> list[str]: @@ -260,6 +350,7 @@ def _get_linked_details(self, token_data: dict, existing_linked_connections: lis async def _set_encrypted_session(self, token_data, state: str | None = None) -> str: id_token = token_data.get("id_token", "") + decoded_id_token = {} if id_token: try: decoded_id_token = await self.token_verifier.verify_signature(id_token) @@ -283,13 +374,15 @@ async def _set_encrypted_session(self, token_data, state: str | None = None) -> if existing_encrypted_session: # found existing session, check if there is a refresh token to keep existing_session = self._get_encrypted_session(user_id) - existing_refresh_token = existing_session.get( - "tokens").get("refresh_token", None) - existing_linked_connections = existing_session.get( - "linked_connections") + if existing_session: + existing_refresh_token = existing_session.get( + "tokens").get("refresh_token", None) + existing_linked_connections = existing_session.get( + "linked_connections", None) session_data = {} - session_data = {"user_id": user_id, + session_data = {"user": decoded_id_token, + "id_token": {"id_token": id_token,"id_token_expiry": decoded_id_token.get("exp")}, "tokens": self._get_token_set(token_data, existing_refresh_token), "linked_connections": self._get_linked_details(token_data, existing_linked_connections) } @@ -318,7 +411,7 @@ def _get_encrypted_session(self, user_id): encrypted_session, self.secret_key, algorithms=["HS256"]) # Extract the user ID (sub) from the decoded JWT - user_id = decoded_data["user_id"] + user_id = decoded_data.get('user').get('sub') token_expiry = decoded_data.get("tokens", {}).get( "expires_at", {}).get("epoch") @@ -340,7 +433,7 @@ def _get_encrypted_session(self, user_id): return {"Invalid session."} def _update_encrypted_session(self, user_id, refresh_token): - token_manager = GetToken() + token_manager = GetToken(self.domain, self.client_id, self.client_secret) updated_tokens = token_manager.refresh_token( refresh_token=refresh_token) @@ -355,7 +448,6 @@ def get_authorize_url( additional_scopes: str | None = None, **kwargs, ) -> str: - base_url = ( f"https://{self.domain}/authorize?" f"response_type=code&" @@ -431,15 +523,10 @@ def get_upstream_token( return (x) - def get_session_details(self, user_id: str) -> dict[str, Any]: - if user_id in self.session_store._get_stored_sessions(): - return (self._get_encrypted_session(user_id)) - else: - return {"user_id not found in session store"} - def get_session(self, user: User) -> dict[str, Any]: if user.user_id in self.session_store._get_stored_sessions(): - return (self._get_encrypted_session(user.user_id)) + session = self._get_encrypted_session(user.user_id) + return (session.get("user")) else: return {"user_id not found in session store"} @@ -596,8 +683,8 @@ class User(AIAuth): def __init__(self, parent, user_id: str): self.parent = parent - self.user_id = user_id self.client = parent.client + self.user_id = user_id async def link(self, connection: str | None = None, scope: str | None = None, **kwargs) -> str: return await self.parent.link(connection=connection, primary_user_id=self.user_id, id_token=self.get_id_token(), scope=scope, **kwargs) From ab7931aa72772718595c6ac202de86bc5a8f93aa Mon Sep 17 00:00:00 2001 From: mustafadeel Date: Mon, 17 Feb 2025 22:25:08 -0500 Subject: [PATCH 09/11] added return_to to /auth/login --- packages/auth0-ai/auth0_ai/ai_auth.py | 52 +++++++++++-------- packages/auth0-ai/auth0_ai/session_storage.py | 1 + 2 files changed, 32 insertions(+), 21 deletions(-) diff --git a/packages/auth0-ai/auth0_ai/ai_auth.py b/packages/auth0-ai/auth0_ai/ai_auth.py index ec0f860..e74f1f3 100644 --- a/packages/auth0-ai/auth0_ai/ai_auth.py +++ b/packages/auth0-ai/auth0_ai/ai_auth.py @@ -146,6 +146,19 @@ async def manage_callback(request: Request, response: Response): cookie_session_data = await self._set_encrypted_session(auth0_tokens, state=received_state) + user_id = self.state_store[received_state].get( + "user_id", "failed") + return_to = self.state_store[received_state].get("return_to", None) + self.state_store[received_state] = { + "user_id": user_id, "is_completed": True} + + if return_to: + response = RedirectResponse(url=return_to, status_code=302) + + else: + response.body = b'{"message": "login successful"}' + response.status_code = 200 + response.set_cookie( key="sessionData", value=cookie_session_data, @@ -156,19 +169,16 @@ async def manage_callback(request: Request, response: Response): # set expiry based on access token expiry max_age=auth0_tokens["expires_in"], ) - - user_id = self.state_store[received_state].get( - "user_id", "failed") - self.state_store[received_state] = { - "user_id": user_id, "is_completed": True} - - return {"message": "successul. you can now close this window"} - - # Register the callback route + + return response + else: + raise HTTPException( + status_code=400, detail="Failed to exchange code for tokens.") # Register the login route @self.app.get("/auth/login") - async def manage_login(request: Request, response: Response): + async def manage_login(request: Request, response: Response, + return_to: str | None = None, scope: str | None = None, connection: str | None = None): # check cookie for existing session auth_cookie = request.cookies.get("sessionData") @@ -180,15 +190,13 @@ async def manage_login(request: Request, response: Response): return {"session": decoded_data} else: # No session cookie, redirect to Auth0 - query_params = urllib.parse.parse_qs(request.query_params) - scope = query_params.get("scope", ["openid profile email"])[0] - connection = query_params.get("connection", ["Username-Password-Authentication"])[0] - return_to = query_params.get("return_to", ["/"])[0] + _scope = scope or "openid profile email" + _connection = connection or "Username-Password-Authentication" - state = self._generate_state() + state = self._generate_state(return_to=return_to) auth_url = self.get_authorize_url( - state=state, connection=connection, scope=scope, redirect_uri=return_to) + state=state, connection=_connection, scope=_scope) return RedirectResponse(url=auth_url, status_code=302) @@ -244,8 +252,10 @@ async def manage_logout(request: Request, response: Response): status_code=400, detail="Invalid session cookie: Missing 'sub' claim.") self.get(url=f"https://{self.domain}/v2/logout") - revoke_rt = RevokeToken(self.domain, self.client_id, self.client_secret) - revoke_rt.revoke_refresh_token(token=decoded_data.get("tokens").get("refresh_token")) + rt = token=decoded_data.get("tokens").get("refresh_token", None) + if rt: + rt_manager = RevokeToken(self.domain, self.client_id, self.client_secret) + rt_manager.revoke_refresh_token(token=decoded_data.get("tokens").get("refresh_token")) response.delete_cookie(key="sessionData",path="/auth") self.session_store._delete_stored_session(user_id) @@ -312,11 +322,11 @@ def _start_server(self): except Exception as e: print(f"Error starting middleware server: {str(e)}") - def _generate_state(self) -> str: + def _generate_state(self, return_to: str | None = None) -> str: """Generate a secure random state and store it for validation.""" state = secrets.token_urlsafe(16) # Generate a random state # Store it temporarily and flag it as false as we havent received it back as yet - self.state_store[state] = {"is_competed": False} + self.state_store[state] = {"is_competed": False, "return_to": return_to} return state def _get_token_set(self, token_data: str, existing_refresh_token: str | None = None) -> dict: @@ -394,7 +404,7 @@ async def _set_encrypted_session(self, token_data, state: str | None = None) -> self.session_store._set_stored_session( user_id=user_id, encrypted_session_data=encrypted_session_data) - self.state_store[state] = {"user_id": user_id} + self.state_store[state]["user_id"] = user_id # print("Session created/updated for:",user_id) return encrypted_session_data diff --git a/packages/auth0-ai/auth0_ai/session_storage.py b/packages/auth0-ai/auth0_ai/session_storage.py index 8b94c4f..408b59d 100644 --- a/packages/auth0-ai/auth0_ai/session_storage.py +++ b/packages/auth0-ai/auth0_ai/session_storage.py @@ -56,5 +56,6 @@ def _delete_stored_session(self, user_id): with shelve.open(".sessions_cache") as sessions: if user_id in sessions: del sessions[user_id] + sessions.sync() else: self.del_session() From 709fa059645855dd6864dd7f0acac4df68155b95 Mon Sep 17 00:00:00 2001 From: Snehil Kishore Date: Wed, 19 Feb 2025 11:16:31 +0530 Subject: [PATCH 10/11] Modular Structure --- packages/auth0-ai/auth0_ai/auth/__init__.py | 5 + .../auth0-ai/auth0_ai/auth/auth_client.py | 189 ++++++++++++++++++ packages/auth0-ai/auth0_ai/auth/base.py | 47 +++++ packages/auth0-ai/auth0_ai/auth/user.py | 113 +++++++++++ packages/auth0-ai/auth0_ai/server/__init__.py | 8 + .../auth0-ai/auth0_ai/server/auth_server.py | 47 +++++ packages/auth0-ai/auth0_ai/server/routes.py | 105 ++++++++++ .../auth0_ai/session_module/__init__.py | 12 ++ .../auth0_ai/session_module/manager.py | 181 +++++++++++++++++ .../session_module/storage/__init__.py | 7 + .../session_module/storage/base_store.py | 45 +++++ .../session_module/storage/local_store.py | 51 +++++ packages/auth0-ai/auth0_ai/state/__init__.py | 9 + .../auth0-ai/auth0_ai/state/base_state.py | 38 ++++ .../auth0-ai/auth0_ai/state/link_state.py | 65 ++++++ .../auth0-ai/auth0_ai/state/login_state.py | 63 ++++++ .../auth0_ai/token_module/__init__.py | 6 + .../auth0-ai/auth0_ai/token_module/manager.py | 183 +++++++++++++++++ packages/auth0-ai/auth0_ai/utils/__init__.py | 7 + .../auth0-ai/auth0_ai/utils/url_builder.py | 117 +++++++++++ 20 files changed, 1298 insertions(+) create mode 100644 packages/auth0-ai/auth0_ai/auth/__init__.py create mode 100644 packages/auth0-ai/auth0_ai/auth/auth_client.py create mode 100644 packages/auth0-ai/auth0_ai/auth/base.py create mode 100644 packages/auth0-ai/auth0_ai/auth/user.py create mode 100644 packages/auth0-ai/auth0_ai/server/__init__.py create mode 100644 packages/auth0-ai/auth0_ai/server/auth_server.py create mode 100644 packages/auth0-ai/auth0_ai/server/routes.py create mode 100644 packages/auth0-ai/auth0_ai/session_module/__init__.py create mode 100644 packages/auth0-ai/auth0_ai/session_module/manager.py create mode 100644 packages/auth0-ai/auth0_ai/session_module/storage/__init__.py create mode 100644 packages/auth0-ai/auth0_ai/session_module/storage/base_store.py create mode 100644 packages/auth0-ai/auth0_ai/session_module/storage/local_store.py create mode 100644 packages/auth0-ai/auth0_ai/state/__init__.py create mode 100644 packages/auth0-ai/auth0_ai/state/base_state.py create mode 100644 packages/auth0-ai/auth0_ai/state/link_state.py create mode 100644 packages/auth0-ai/auth0_ai/state/login_state.py create mode 100644 packages/auth0-ai/auth0_ai/token_module/__init__.py create mode 100644 packages/auth0-ai/auth0_ai/token_module/manager.py create mode 100644 packages/auth0-ai/auth0_ai/utils/__init__.py create mode 100644 packages/auth0-ai/auth0_ai/utils/url_builder.py diff --git a/packages/auth0-ai/auth0_ai/auth/__init__.py b/packages/auth0-ai/auth0_ai/auth/__init__.py new file mode 100644 index 0000000..1e15861 --- /dev/null +++ b/packages/auth0-ai/auth0_ai/auth/__init__.py @@ -0,0 +1,5 @@ +from .auth_client import AIAuth, User +from .base import BaseAuth +from .user import User + +__all__ = ["AIAuth", "BaseAuth", "User"] diff --git a/packages/auth0-ai/auth0_ai/auth/auth_client.py b/packages/auth0-ai/auth0_ai/auth/auth_client.py new file mode 100644 index 0000000..e02d377 --- /dev/null +++ b/packages/auth0-ai/auth0_ai/auth/auth_client.py @@ -0,0 +1,189 @@ +from __future__ import annotations +from typing import Any +import webbrowser +import secrets +import json +from typing import Any, Dict + +from auth0.authentication.async_token_verifier import AsyncAsymmetricSignatureVerifier +from auth0.authentication.pushed_authorization_requests import PushedAuthorizationRequests + +from .base import BaseAuth +from .user import User +from server.auth_server import AuthServer +from token_module.manager import TokenManager +from session_module.manager import SessionManager +from state.login_state import LoginState +from state.link_state import LinkState +from utils.url_builder import URLBuilder + +class AIAuth(BaseAuth): + """Main authentication class that orchestrates the auth flow""" + def __init__( + self, + domain: str | None = None, + client_id: str | None = None, + client_secret: str | None = None, + redirect_uri: str | None = None, + secret_key: str | None = None, + *args, **kwargs): + """Initialize AIAuth with all necessary components""" + super().__init__( + domain=domain, + client_id=client_id, + client_secret=client_secret, + redirect_uri=redirect_uri, + secret_key=secret_key, + *args, **kwargs + ) + # Initialize token verifier + jwk_url = f"https://{self.domain}/.well-known/jwks.json" + self.token_verifier = AsyncAsymmetricSignatureVerifier(jwks_url=jwk_url) + # Initialize components + self.state_store: Dict[str, Dict[str, Any]] = {} + self.session_manager = SessionManager(self) + self.token_manager = TokenManager(self) + self.url_builder = URLBuilder(self) + # Initialize server + self.server= AuthServer(self) + + def _generate_state(self) -> str: + """Generate a secure random state and store it""" + state = secrets.token_urlsafe(16) + self.state_store[state] = {"is_completed": False} + return state + async def interactive_login( + self, + connection: str | None = None, + scope: str | None = None, + **kwargs + ) -> User: + """ + Handle interactive login flow. + Args: + connection: Optional connection to use + scope: OAuth scope (default: "openid profile email") + **kwargs: Additional parameters for authorization + Returns: + User instance if successful, error string if failed + """ + if scope is None: + scope = "openid profile email" + # Generate state and create login state tracker + state = self._generate_state() + login_state = LoginState(self.state_store, state) + # Generate authorization URL + auth_url = self.url_builder.get_authorize_url( + state=state, + connection=connection, + scope=scope, + **kwargs + ) + # Open browser for authentication + try: + webbrowser.open(auth_url) + except webbrowser.Error: + print(f"Please navigate here: {auth_url}") + # Wait for authentication completion + user_id = await login_state.wait_for_completion() + if not user_id: + return "login failed" + return User(self, user_id=user_id) + + async def link( + self, + primary_user_id: str, + connection: str, + id_token: str, + scope: str | None = None, + **kwargs + ) -> Dict[str, Any]: + """ + Handle account linking flow. + Args: + primary_user_id: ID of the user initiating the link + connection: Connection to link + id_token: ID token of the primary user + scope: OAuth scope + **kwargs: Additional parameters for authorization + Returns: + Dict containing link status and user information + """ + state = self._generate_state() + link_state = LinkState(self.state_store, state) + link_state.set_user(primary_user_id) + try: + # Create PAR request + par_client = PushedAuthorizationRequests( + self.domain, + self.client_id, + self.client_secret + ) + par_response = await par_client.pushed_authorization_request( + response_type="code", + nonce=kwargs.get("nonce", "mynonce"), + redirect_uri=self.redirect_uri, + audience=kwargs.get("audience", "my-account"), + state=state, + authorization_details=json.dumps([ + {"type": "link_account", "requested_connection": connection} + ]), + scope=scope or "openid profile", + id_token_hint=id_token, + **kwargs + ) + request_uri = par_response.get('request_uri') + if not request_uri: + raise ValueError("Failed to get request_uri from PAR response") + # Generate authorization URL with PAR + auth_url = self.url_builder.get_authorize_par_url( + state=state, + request_uri=request_uri + ) + # Open browser for linking + try: + webbrowser.open(auth_url) + except webbrowser.Error: + print(f"Please navigate here: {auth_url}") + # Wait for linking completion + user_id = await link_state.wait_for_completion() + return { + "is_successful": bool(user_id), + "user_id": user_id or primary_user_id + } + except Exception as error: + print(f"Error during linking: {error}") + return { + "is_successful": False, + "user_id": primary_user_id, + } + + def get_session_details(self, user_id: str) -> Dict[str, Any]: + """Get session details for a user""" + return self.session_manager.get_session_details(user_id) + + def get_session(self, user: User) -> Dict[str, Any]: + """Get session for a user object""" + return self.session_manager.get_session(user) + + def get_upstream_token( + self, + connection: str, + refresh_token: str, + additional_scopes: str | None = None + ) -> Dict[str, Any]: + """Get token for federated connection""" + return self.token_manager.get_upstream_token( + connection=connection, + refresh_token=refresh_token, + additional_scopes=additional_scopes + ) + + + + + + + + + diff --git a/packages/auth0-ai/auth0_ai/auth/base.py b/packages/auth0-ai/auth0_ai/auth/base.py new file mode 100644 index 0000000..f7cfd81 --- /dev/null +++ b/packages/auth0-ai/auth0_ai/auth/base.py @@ -0,0 +1,47 @@ +from __future__ import annotations +import os +from auth0.authentication.base import AuthenticationBase + +class BaseAuth(AuthenticationBase): + """Base authentication class with core properties and validations""" + # Define required config fields and their environment variable names + REQUIRED_CONFIGS = { + 'domain': 'AUTH0_DOMAIN', + 'client_id': 'AUTH0_CLIENT_ID', + 'client_secret': 'AUTH0_CLIENT_SECRET', + 'redirect_uri': 'AUTH0_REDIRECT_URI', + 'secret_key': 'AUTH0_SECRET_KEY' + } + + def __init__( + self, + domain: str | None = None, + client_id: str | None = None, + client_secret: str | None = None, + redirect_uri: str | None = None, + secret_key: str | None = None, + *args, **kwargs): + + # Initialize all config properties + for field, env_var in self.REQUIRED_CONFIGS.items(): + value = locals().get(field) or os.environ.get(env_var) + setattr(self, f'_{field}', None) # Initialize private attribute + setattr(self.__class__, field, property( # Create property + fget=lambda self, f=field: getattr(self, f'_{f}'), + fset=lambda self, value, f=field: self._validate_and_set(f, value) + )) + setattr(self, field, value) # Set the value using property setter + + super().__init__( + domain=self.domain, + client_id=self.client_id, + client_secret=self.client_secret, + *args, **kwargs + ) + + def _validate_and_set(self, field: str, value: str | None) -> None: + """Validate and set a configuration value""" + if not value: + raise ValueError( + f"{field} cannot be empty. You can also set {self.REQUIRED_CONFIGS[field]} value in .env file") + setattr(self, f'_{field}', value) \ No newline at end of file diff --git a/packages/auth0-ai/auth0_ai/auth/user.py b/packages/auth0-ai/auth0_ai/auth/user.py new file mode 100644 index 0000000..429449a --- /dev/null +++ b/packages/auth0-ai/auth0_ai/auth/user.py @@ -0,0 +1,113 @@ +from __future__ import annotations +from typing import Any, Dict + +class User: + """ + Represents an authenticated user with token management capabilities. + """ + def __init__(self, auth_client, user_id: str): + """ + Initialize a User instance. + Args: + auth_client: The parent AIAuth instance + user_id: The unique identifier for the user + """ + self._auth_client = auth_client + self._user_id = user_id + @property + def user_id(self) -> str: + """Get the user's ID""" + return self._user_id + async def link(self, connection: str, scope: str | None = None, **kwargs) -> Dict[str, Any]: + """ + Link another authentication provider to this user account. + Args: + connection: The name of the connection to link (e.g., 'github', 'google') + scope: OAuth scope for the connection + **kwargs: Additional parameters for the linking process + Returns: + Dict containing link status and user information + """ + return await self._auth_client.link( + primary_user_id=self.user_id, + connection=connection, + id_token=self.get_id_token(), + scope=scope, + **kwargs + ) + + def get_id_token(self) -> str: + """Get the user's ID token""" + return self._auth_client.token_manager.get_id_token(self.user_id) + + def get_access_token(self) -> str: + """Get the user's access token""" + return self._auth_client.token_manager.get_access_token(self.user_id) + + def get_refresh_token(self) -> str: + """Get the user's refresh token""" + return self._auth_client.token_manager.get_refresh_token(self.user_id) + + def get_3rd_party_token(self, connection: str) -> Dict[str, Any]: + """ + Get access token for a linked third-party provider. + Args: + connection: The name of the third-party connection (e.g., 'github') + Returns: + Dict containing the third-party access token and related information + """ + refresh_token = self.get_refresh_token() + return self._auth_client.get_upstream_token(connection, refresh_token) + + async def get_profile(self) -> Dict[str, Any]: + """ + Get the user's profile information. + Returns: + Dict containing user profile data + """ + access_token = self.get_access_token() + return await self._auth_client.token_manager.get_userinfo(access_token) + + async def get_token_info(self) -> Dict[str, Any]: + """ + Get detailed information about the user's tokens. + Returns: + Dict containing token information + """ + id_token = self.get_id_token() + access_token = self.get_access_token() + return await self._auth_client.token_manager.get_tokeninfo(id_token, access_token) + + def is_token_valid(self) -> bool: + """ + Check if the user's tokens are still valid. + Returns: + bool indicating token validity + """ + return self._auth_client.token_manager.validate_tokens(self.user_id) + + async def refresh_tokens(self) -> bool: + """ + Refresh the user's tokens if they're expired. + Returns: + bool indicating if refresh was successful + """ + refresh_token = self.get_refresh_token() + if not refresh_token: + return False + return await self._auth_client.token_manager.refresh_tokens(self.user_id, refresh_token) + + def get_session(self) -> Dict[str, Any]: + """ + Get the current session information for the user. + Returns: + Dict containing session information + """ + return self._auth_client.get_session(self) + + + + + + + diff --git a/packages/auth0-ai/auth0_ai/server/__init__.py b/packages/auth0-ai/auth0_ai/server/__init__.py new file mode 100644 index 0000000..4dee87a --- /dev/null +++ b/packages/auth0-ai/auth0_ai/server/__init__.py @@ -0,0 +1,8 @@ +""" +Auth0 AI Server Module +Internal module for handling OAuth callback server and routes. +""" +from .auth_server import AuthServer +from .routes import setup_routes + +__all__ = ["AuthServer","setup_routes"] diff --git a/packages/auth0-ai/auth0_ai/server/auth_server.py b/packages/auth0-ai/auth0_ai/server/auth_server.py new file mode 100644 index 0000000..15e8d39 --- /dev/null +++ b/packages/auth0-ai/auth0_ai/server/auth_server.py @@ -0,0 +1,47 @@ +from __future__ import annotations +import threading +import urllib.parse +from typing import Any + +import uvicorn +from fastapi import FastAPI + +from .routes import setup_routes + +class AuthServer: + """ + FastAPI server handling Auth0 callbacks and authentication routes. + """ + + def __init__(self, auth_client: Any): + """ + Initialize the authentication server. + + Args: + auth_client: The parent AIAuth instance + """ + self.auth_client = auth_client + self.app = FastAPI() + + # Parse redirect URI for server config + parsed_uri = urllib.parse.urlparse(auth_client.redirect_uri) + self.host = parsed_uri.hostname + self.port = parsed_uri.port + + # Setup routes with dependencies + setup_routes(self.app, auth_client) + self.start() + + def start(self) -> None: + """Start the FastAPI server in a daemon thread.""" + server_thread = threading.Thread( + target=uvicorn.run, + args=(self.app,), + kwargs={ + "host": self.host or "localhost", + "port": self.port or "3000", + "log_level": "info" + }, + daemon=False + ) + server_thread.start() diff --git a/packages/auth0-ai/auth0_ai/server/routes.py b/packages/auth0-ai/auth0_ai/server/routes.py new file mode 100644 index 0000000..549fd18 --- /dev/null +++ b/packages/auth0-ai/auth0_ai/server/routes.py @@ -0,0 +1,105 @@ +from __future__ import annotations +from typing import Any + +from fastapi import FastAPI, Request, Response, HTTPException +from fastapi.responses import JSONResponse, RedirectResponse +import jwt + + +def setup_routes(app: FastAPI, auth_client: Any) -> None: + """Set up all routes for the authentication server.""" + + @app.get("/auth/callback") + async def manage_callback(request: Request, response: Response): + """Parses and validates callback URL query parameters.""" + query_params = request.query_params + required_keys = {"code", "state"} + + if query_params.get("error"): + error_description = query_params.get("error_description", "Unknown error occurred.") + if query_params.get("state"): + del auth_client.state_store[query_params.get("state")] + raise HTTPException(status_code=400, detail=error_description) + + if not required_keys.issubset(query_params.keys()): + raise HTTPException(status_code=400, detail="Missing required query parameters.") + + received_state = query_params["state"] + + # Validate state to prevent CSRF attacks + if received_state not in auth_client.state_store: + raise HTTPException(status_code=400, detail="Invalid or missing state parameter.") + + # Extract code value from query string + received_code = query_params["code"] + + auth0_tokens = auth_client.token_manager.exchange_code_for_tokens(received_code) + + if auth0_tokens: + cookie_data = await auth_client.session_manager.set_encrypted_session(auth0_tokens, state=received_state) + + response.set_cookie( + key="session", + value=cookie_data, + httponly=True, # Prevent JavaScript access + # secure=True, # Send only over HTTPS + samesite="Lax", # Protect against CSRF + # set expiry based on access token expiry + max_age=auth0_tokens["expires_in"], + ) + + user_id = auth_client.state_store[received_state].get("user_id", "failed") + auth_client.state_store[received_state] = {"user_id": user_id, "is_completed": True} + + return {"message": "successful. you can now close this window"} + + @app.get("/auth/login") + async def manage_login(request: Request, response: Response): + """Handle login initiation.""" + # if scope is None: # Original comment preserved + scope = "openid profile email" + connection = "Username-Password-Authentication" + + state = auth_client._generate_state() + + auth_url = auth_client.url_builder.get_authorize_url( + state=state, + connection=connection, + scope=scope + ) + + return RedirectResponse(url=auth_url, status_code=302) + + @app.get("/auth/get_user") + async def get_user(request: Request): + """Reads the session cookie and extracts user info.""" + auth_cookie = request.cookies.get("session") + + if not auth_cookie: + raise HTTPException(status_code=401, detail="Missing session cookie.") + + try: + # Decode the JWT stored in the session cookie + decoded_data = jwt.decode(auth_cookie, auth_client.secret_key, algorithms=["HS256"]) + + # Extract the user ID (sub) from the decoded JWT + user_id = decoded_data["user_id"] + + if not user_id: + raise HTTPException(status_code=400, detail="Invalid session cookie: Missing 'sub' claim.") + + return JSONResponse(content=decoded_data) + + except jwt.ExpiredSignatureError: + raise HTTPException(status_code=401, detail="Session cookie has expired.") + except jwt.InvalidTokenError: + raise HTTPException(status_code=401, detail="Invalid session cookie.") + + + + + + + + + diff --git a/packages/auth0-ai/auth0_ai/session_module/__init__.py b/packages/auth0-ai/auth0_ai/session_module/__init__.py new file mode 100644 index 0000000..8e41a90 --- /dev/null +++ b/packages/auth0-ai/auth0_ai/session_module/__init__.py @@ -0,0 +1,12 @@ +""" +Session Management Module +Provides session handling, storage, and encryption capabilities. +""" +from .manager import SessionManager +from .storage.base_store import BaseStore +from .storage.local_store import LocalStore +__all__ = [ + "SessionManager", + "BaseStore", + "LocalStore" +] diff --git a/packages/auth0-ai/auth0_ai/session_module/manager.py b/packages/auth0-ai/auth0_ai/session_module/manager.py new file mode 100644 index 0000000..7318417 --- /dev/null +++ b/packages/auth0-ai/auth0_ai/session_module/manager.py @@ -0,0 +1,181 @@ +from __future__ import annotations +from typing import Any, Dict, Optional +import jwt +import time + +from .storage.base_store import BaseStore +from .storage.local_store import LocalStore + +class SessionManager: + """ + Manages session operations including encryption, storage, and retrieval. + Maintains exact compatibility with original implementation. + """ + + def __init__( + self, + auth_client: Any, + use_local_cache: bool = True, + get_sessions=None, + get_session=None, + set_session=None, + delete_session=None, + store: Optional[BaseStore] = None + ): + """ + Initialize session manager with original parameters plus optional store. + + Args: + auth_client: Parent AIAuth instance + use_local_cache: Whether to use local cache (default: True) + get_sessions: Optional custom get_sessions function + get_session: Optional custom get_session function + set_session: Optional custom set_session function + delete_session: Optional custom delete_session function + store: Optional custom store implementation + """ + self.auth_client = auth_client + self.store = store or LocalStore(use_local_cache=use_local_cache) + self.secret_key = auth_client.secret_key + + # Custom function handlers + self.get_sessions = get_sessions + self.get_session = get_session + self.set_session = set_session + self.delete_session = delete_session + + # Original interface methods with exact same names and signatures + def _get_stored_sessions(self) -> Any: + """Get all stored session IDs""" + if hasattr(self, 'get_sessions') and self.get_sessions: + return self.get_sessions() + return self.store.get_stored_sessions() + + def _get_stored_session(self, user_id: str) -> str: + """Get a specific stored session""" + if hasattr(self, 'get_session') and self.get_session: + return self.get_session() + return self.store.get_stored_session(user_id) + + def _set_stored_session(self, user_id: str, encrypted_session_data: str) -> None: + """Store a session""" + if hasattr(self, 'set_session') and self.set_session: + self.set_session() + else: + self.store.set_stored_session(user_id, encrypted_session_data) + + def _delete_stored_session(self, user_id: str) -> None: + """Delete a stored session""" + if hasattr(self, 'delete_session') and self.delete_session: + self.delete_session() + else: + self.store.delete_stored_session(user_id) + + # Session encryption and management methods (from original auth_client.py) + async def set_encrypted_session(self, token_data: dict, state: str | None = None) -> str: + """Create or update encrypted session""" + id_token = token_data.get("id_token", "") + if id_token: + try: + decoded_id_token = await self.auth_client.token_verifier.verify_signature(id_token) + user_id = decoded_id_token.get("sub") + if not user_id: + raise ValueError("ID token missing 'sub' claim.") + except Exception as e: + raise ValueError(f"Invalid ID token: {str(e)}") + else: + user_id = self.auth_client.state_store[state].get("user_id") if state else None + + existing_encrypted_session = self._get_stored_session(user_id) + existing_linked_connections = {} + existing_refresh_token = {} + + if existing_encrypted_session: + existing_session = self.get_encrypted_session(user_id) + existing_refresh_token = existing_session.get("tokens", {}).get("refresh_token", None) + existing_linked_connections = existing_session.get("linked_connections") + + session_data = { + "user_id": user_id, + "tokens": self._get_token_set(token_data, existing_refresh_token), + "linked_connections": self._get_linked_details(token_data, existing_linked_connections) + } + + encrypted_session_data = jwt.encode(session_data, self.secret_key, algorithm="HS256") + self._set_stored_session(user_id, encrypted_session_data) + + if state: + self.auth_client.state_store[state] = {"user_id": user_id} + + return encrypted_session_data + + def get_encrypted_session(self, user_id: str) -> Dict[str, Any]: + """Retrieve and decrypt session data""" + encrypted_session = self._get_stored_session(user_id) + + if not encrypted_session: + return {"not found"} + + try: + decoded_data = jwt.decode(encrypted_session, self.secret_key, algorithms=["HS256"]) + + token_expiry = decoded_data.get("tokens", {}).get("expires_at", {}).get("epoch") + if token_expiry > int(time.time()): + return decoded_data + + refresh_token = decoded_data.get("tokens", {}).get("refresh_token") + if refresh_token: + self._update_encrypted_session(user_id, refresh_token) + else: + self._delete_stored_session(user_id) + return {"session expired"} + + except jwt.ExpiredSignatureError: + return {"Session cookie has expired."} + except jwt.InvalidTokenError: + return {"Invalid session."} + + def _update_encrypted_session(self, user_id: str, refresh_token: str) -> None: + """Update session with refreshed tokens""" + token_manager = self.auth_client.token_manager + updated_tokens = token_manager.refresh_token(refresh_token=refresh_token) + if updated_tokens: + self.set_encrypted_session(updated_tokens) + + + def get_session_details(self, user_id: str) -> Dict[str, Any]: + """Get session details for user""" + if user_id in self._get_stored_sessions(): + return self.get_encrypted_session(user_id) + return {"user_id not found in session store"} + + def get_session(self, user: Any) -> Dict[str, Any]: + """Get session for user object""" + if user.user_id in self._get_stored_sessions(): + return self.get_encrypted_session(user.user_id) + return {"user_id not found in session store"} + + def _get_token_set(self, token_data: dict, existing_refresh_token: str | None = None) -> dict: + """Format token data with expiry time""" + return { + "access_token": token_data.get("access_token"), + "expires_at": {"epoch": int(time.time()) + token_data["expires_in"]}, + "refresh_token": token_data.get("refresh_token", existing_refresh_token), + "id_token": token_data.get("id_token"), + "scope": token_data.get("scope"), + } + + def _get_linked_details(self, token_data: dict, existing_linked_connections: list[str] | None = None) -> list[str]: + """Extract linked connections from token data""" + linked_connections = set(existing_linked_connections or []) + + for item in token_data.get("authorization_details", []): + if item.get("type") == "account_linking": + link_with = item.get("linkParams", {}).get("link_with") + if link_with: + linked_connections.add(link_with) + + return list(linked_connections) + + + diff --git a/packages/auth0-ai/auth0_ai/session_module/storage/__init__.py b/packages/auth0-ai/auth0_ai/session_module/storage/__init__.py new file mode 100644 index 0000000..bfc0e98 --- /dev/null +++ b/packages/auth0-ai/auth0_ai/session_module/storage/__init__.py @@ -0,0 +1,7 @@ +""" +Session Storage Implementations +""" +from .base_store import BaseStore +from .local_store import LocalStore + +__all__ = ["BaseStore", "LocalStore"] diff --git a/packages/auth0-ai/auth0_ai/session_module/storage/base_store.py b/packages/auth0-ai/auth0_ai/session_module/storage/base_store.py new file mode 100644 index 0000000..5e57a74 --- /dev/null +++ b/packages/auth0-ai/auth0_ai/session_module/storage/base_store.py @@ -0,0 +1,45 @@ +from __future__ import annotations +from abc import ABC, abstractmethod +from typing import List + +class BaseStore(ABC): + """ + Abstract base class defining the interface for session storage implementations. + All storage implementations must inherit from this class and implement + all abstract methods. + """ + @abstractmethod + def get_stored_sessions(self) -> List[str]: + """ + Get all stored session IDs. + Returns: + List of session IDs + """ + pass + @abstractmethod + def get_stored_session(self, user_id: str) -> str | None: + """ + Get a specific stored session. + Args: + user_id: The ID of the user whose session to retrieve + Returns: + The session data if found, None otherwise + """ + pass + @abstractmethod + def set_stored_session(self, user_id: str, encrypted_session_data: str) -> None: + """ + Store a session. + Args: + user_id: The ID of the user whose session to store + encrypted_session_data: The encrypted session data to store + """ + pass + @abstractmethod + def delete_stored_session(self, user_id: str) -> None: + """ + Delete a stored session. + Args: + user_id: The ID of the user whose session to delete + """ + pass \ No newline at end of file diff --git a/packages/auth0-ai/auth0_ai/session_module/storage/local_store.py b/packages/auth0-ai/auth0_ai/session_module/storage/local_store.py new file mode 100644 index 0000000..5a8540d --- /dev/null +++ b/packages/auth0-ai/auth0_ai/session_module/storage/local_store.py @@ -0,0 +1,51 @@ +from __future__ import annotations +import shelve +import os +from typing import List + +from .base_store import BaseStore + +class LocalStore(BaseStore): + """ + Local storage implementation using Python's shelve module. + This is the default storage mechanism, maintaining the original implementation's behavior. + """ + + def __init__(self, file_path: str = ".sessions_cache", use_local_cache: bool = True): + """ + Initialize local store. + + Args: + file_path: Path to the shelve file (default: ".sessions_cache") + use_local_cache: Flag to determine if local cache should be used (default: True) + """ + self.file_path = file_path + self.use_local_cache = use_local_cache or os.environ.get("AUTH0_USE_LOCAL_CACHE", True) + + def get_stored_sessions(self) -> List[str]: + """Get all stored session IDs""" + if self.use_local_cache: + with shelve.open(self.file_path) as sessions: + return list(sessions.keys()) + return [] + + def get_stored_session(self, user_id: str) -> str | None: + """Get a specific stored session""" + if self.use_local_cache: + with shelve.open(self.file_path) as sessions: + return sessions.get(user_id) + return None + + def set_stored_session(self, user_id: str, encrypted_session_data: str) -> None: + """Store a session""" + if self.use_local_cache: + with shelve.open(self.file_path) as sessions: + sessions[user_id] = encrypted_session_data + sessions.sync() + + def delete_stored_session(self, user_id: str) -> None: + """Delete a stored session""" + if self.use_local_cache: + with shelve.open(self.file_path) as sessions: + if user_id in sessions: + del sessions[user_id] diff --git a/packages/auth0-ai/auth0_ai/state/__init__.py b/packages/auth0-ai/auth0_ai/state/__init__.py new file mode 100644 index 0000000..13635f8 --- /dev/null +++ b/packages/auth0-ai/auth0_ai/state/__init__.py @@ -0,0 +1,9 @@ +""" +Auth0 AI State Management Module +Internal module for handling authentication and linking state. +""" +from .base_state import BaseState +from .login_state import LoginState +from .link_state import LinkState + +__all__ = ["BaseState", "LoginState", "LinkState"] \ No newline at end of file diff --git a/packages/auth0-ai/auth0_ai/state/base_state.py b/packages/auth0-ai/auth0_ai/state/base_state.py new file mode 100644 index 0000000..9294aab --- /dev/null +++ b/packages/auth0-ai/auth0_ai/state/base_state.py @@ -0,0 +1,38 @@ +from __future__ import annotations +import asyncio +from typing import Any, Dict +from abc import ABC, abstractmethod + +class BaseState(ABC): + """ + Base class for state management in authentication flows. + """ + def __init__(self, state_store: Dict[str, Dict[str, Any]], state: str): + """ + Initialize base state tracker. + Args: + state_store: Reference to the global state store + state: Unique state identifier for this flow + """ + self.state_store = state_store + self.state = state + + @abstractmethod + def is_completed(self) -> bool: + """Check if flow is completed""" + pass + @abstractmethod + def get_user(self) -> Any: + """Get user information after flow completion""" + pass + @abstractmethod + def complete(self, user_id: str) -> None: + """Mark flow as complete with user information""" + pass + def terminate(self) -> None: + """Clean up state data""" + if self.state in self.state_store: + del self.state_store[self.state] + async def _sleep(self, seconds: float) -> None: + """Async sleep helper""" + await asyncio.sleep(seconds) \ No newline at end of file diff --git a/packages/auth0-ai/auth0_ai/state/link_state.py b/packages/auth0-ai/auth0_ai/state/link_state.py new file mode 100644 index 0000000..9a7d862 --- /dev/null +++ b/packages/auth0-ai/auth0_ai/state/link_state.py @@ -0,0 +1,65 @@ +from __future__ import annotations +import time +from typing import Any, Dict, Optional +from .base_state import BaseState + +class LinkState(BaseState): + """ + Handles the state management for the account linking flow. + """ + def __init__(self, state_store: Dict[str, Dict[str, Any]], state: str): + """ + Initialize link state tracker. + Args: + state_store: Reference to the global state store + state: Unique state identifier for this linking attempt + """ + super().__init__(state_store, state) + self.start_time = time.time() + self.timeout = 60 # Linking timeout in seconds + + def is_completed(self) -> bool: + """Check if linking flow is completed""" + return self.state_store[self.state].get("is_completed", False) + def get_user(self) -> str: + """Get user information after linking completion""" + return self.state_store.get(self.state, "login failed!").get("user_id") + def set_user(self, user_id: str) -> None: + """ + Set the primary user ID for linking. + Args: + user_id: ID of the user initiating the link + """ + is_completed = self.state_store[self.state].get("is_completed", False) + self.state_store[self.state] = { + "is_completed": is_completed, + "user_id": user_id + } + + def complete(self, user_id: str) -> None: + """ + Mark linking as complete with user information. + Args: + user_id: ID of the linked user + """ + self.state_store[self.state] = { + "user_id": user_id, + "is_completed": True + } + + async def wait_for_completion(self) -> Optional[str]: + """ + Wait for linking completion or timeout. + Returns: + User ID if successful, None if timeout or failure + """ + while not self.is_completed(): + if time.time() > self.start_time + self.timeout: + self.terminate() + return None + await self._sleep(0.25) # Small delay between checks + user_id = self.get_user() + self.terminate() + if user_id == "login failed!": + return None + return user_id \ No newline at end of file diff --git a/packages/auth0-ai/auth0_ai/state/login_state.py b/packages/auth0-ai/auth0_ai/state/login_state.py new file mode 100644 index 0000000..0f27891 --- /dev/null +++ b/packages/auth0-ai/auth0_ai/state/login_state.py @@ -0,0 +1,63 @@ +from __future__ import annotations +import time +from typing import Any, Dict, Optional + +from .base_state import BaseState + +class LoginState(BaseState): + """ + Handles the state management for the login flow. + """ + + def __init__(self, state_store: Dict[str, Dict[str, Any]], state: str): + """ + Initialize login state tracker. + + Args: + state_store: Reference to the global state store + state: Unique state identifier for this login attempt + """ + super().__init__(state_store, state) + self.start_time = time.time() + self.timeout = 60 # Login timeout in seconds + + def is_completed(self) -> bool: + """Check if login flow is completed""" + return self.state_store[self.state].get("is_completed", False) + + def get_user(self) -> str | Dict[str, Any]: + """Get user information after login completion""" + return self.state_store.get(self.state, "login failed!") + + def complete(self, user_id: str) -> None: + """ + Mark login as complete with user information. + + Args: + user_id: ID of the authenticated user + """ + self.state_store[self.state] = { + "user_id": user_id, + "is_completed": True + } + + async def wait_for_completion(self) -> Optional[str]: + """ + Wait for login completion or timeout. + + Returns: + User ID if successful, None if timeout or failure + """ + while not self.is_completed(): + if time.time() > self.start_time + self.timeout: + self.terminate() + return None + await self._sleep(0.25) # Small delay between checks + + user_data = self.get_user() + self.terminate() + + if user_data == "login failed!": + return None + + return user_data.get("user_id") \ No newline at end of file diff --git a/packages/auth0-ai/auth0_ai/token_module/__init__.py b/packages/auth0-ai/auth0_ai/token_module/__init__.py new file mode 100644 index 0000000..5ca6543 --- /dev/null +++ b/packages/auth0-ai/auth0_ai/token_module/__init__.py @@ -0,0 +1,6 @@ +""" +Auth0 AI Token Management Module +Internal module for handling token operations and lifecycle. +""" +from .manager import TokenManager +__all__ = ["TokenManager"] diff --git a/packages/auth0-ai/auth0_ai/token_module/manager.py b/packages/auth0-ai/auth0_ai/token_module/manager.py new file mode 100644 index 0000000..100eab3 --- /dev/null +++ b/packages/auth0-ai/auth0_ai/token_module/manager.py @@ -0,0 +1,183 @@ +from __future__ import annotations +from typing import Any, Dict +import time +from auth0.authentication import GetToken +from auth0.authentication.async_token_verifier import AsyncAsymmetricSignatureVerifier + + +class TokenManager: + """ + Manages token operations, including exchange, refresh, and validation. + """ + def __init__(self, auth_client: Any): + """ + Initialize token manager. + Args: + auth_client: Parent AIAuth instance + """ + self.auth_client = auth_client + self.token_verifier = AsyncAsymmetricSignatureVerifier( + jwks_url=f"https://{auth_client.domain}/.well-known/jwks.json" + ) + + def exchange_code_for_tokens(self, code: str) -> Dict[str, Any]: + """ + Exchange authorization code for tokens. + Args: + code: Authorization code from Auth0 + Returns: + Dict containing access token, refresh token, and ID token + """ + get_token = GetToken( + self.auth_client.domain, + self.auth_client.client_id, + self.auth_client.client_secret + ) + return get_token.authorization_code( + code=code, + redirect_uri=self.auth_client.redirect_uri, + grant_type="authorization_code" + ) + + def get_token_set(self, token_data: dict, existing_refresh_token: str | None = None) -> dict: + """ + Format token data with expiry time. + Args: + token_data: Raw token data from Auth0 + existing_refresh_token: Optional existing refresh token to preserve + Returns: + Formatted token data with expiry information + """ + return { + "access_token": token_data.get("access_token"), + "expires_at": {"epoch": int(time.time()) + token_data["expires_in"]}, + "refresh_token": token_data.get("refresh_token", existing_refresh_token), + "id_token": token_data.get("id_token"), + "scope": token_data.get("scope"), + } + + async def verify_id_token(self, id_token: str) -> Dict[str, Any]: + """ + Verify and decode ID token. + Args: + id_token: ID token to verify + Returns: + Decoded token claims + """ + return await self.token_verifier.verify_signature(id_token) + + def refresh_tokens(self, refresh_token: str) -> Dict[str, Any]: + """ + Refresh access token using refresh token. + Args: + refresh_token: Refresh token to use + Returns: + New token set + """ + token_client = GetToken( + self.auth_client.domain, + self.auth_client.client_id, + self.auth_client.client_secret + ) + return token_client.refresh_token(refresh_token=refresh_token) + + + def get_3rd_party_token(self, connection: str) -> dict[str, Any]: + return self.get_upstream_token(connection, self.get_refresh_token()) + + def tokeninfo(self) -> dict[str, Any]: + id_token = self.get_id_token() + self.parent.post() + data: dict[str, Any] = self.parent.get( + url=f"https://{self.parent.domain}/tokeninfo", + headers={"Authorization": f"Bearer {access_token}"}, + ) + return data + + def get_upstream_token( + self, + connection: str, + refresh_token: str, + additional_scopes: str | None = None + ) -> Dict[str, Any]: + """ + Get token for federated connection. + Args: + connection: Name of the connection (e.g., 'github') + refresh_token: Refresh token to use + additional_scopes: Optional additional scopes to request + Returns: + Token for the federated connection + """ + token_client = GetToken( + self.auth_client.domain, + self.auth_client.client_id, + self.auth_client.client_secret + ) + return token_client.federated_connection_access_token( + subject_token_type="urn:ietf:params:oauth:token-type:refresh_token", + subject_token=refresh_token, + requested_token_type="http://auth0.com/oauth/token-type/federated-connection-access-token", + connection=connection, + grant_type="urn:auth0:params:oauth:grant-type:token-exchange:federated-connection-access-token" + ) + + async def get_userinfo(self, access_token: str) -> Dict[str, Any]: + """ + Get user information using access token. + Args: + access_token: Access token to use + Returns: + User profile information + """ + return await self.auth_client.get( + url=f"https://{self.auth_client.domain}/userinfo", + headers={"Authorization": f"Bearer {access_token}"} + ) + + async def get_tokeninfo(self, id_token: str, access_token: str) -> Dict[str, Any]: + """ + Get detailed token information. + Args: + id_token: ID token + access_token: Access token + Returns: + Detailed token information + """ + return await self.auth_client.get( + url=f"https://{self.auth_client.domain}/tokeninfo", + headers={"Authorization": f"Bearer {access_token}"} + ) + + def validate_tokens(self, token_data: Dict[str, Any]) -> bool: + """ + Check if tokens are still valid. + Args: + token_data: Token data to validate + Returns: + True if tokens are valid, False otherwise + """ + if not token_data.get("expires_at"): + return False + expiry = token_data["expires_at"].get("epoch", 0) + return time.time() < expiry + + # Session Token Methods (used in User.py) + def get_id_token(self, user_id: str) -> Dict[str, Any]: + if user_id in self.auth_client.session_manager._get_stored_sessions(): + return (self.auth_client.session_manager.get_encrypted_session(user_id).get("tokens").get("access_token")) + else: + return {"user_id not found in session store"} + + def get_refresh_token(self, user_id: str) -> Dict[str, Any]: + if user_id in self.auth_client.session_manager._get_stored_sessions(): + return (self.auth_client.session_manager.get_encrypted_session(user_id).get("tokens").get("refresh_token")) + else: + return {"user_id not found in session store"} + + def get_access_token(self, user_id: str) -> Dict[str, Any]: + if user_id in self.auth_client.session_manager._get_stored_sessions(): + return (self.auth_client.session_manager.get_encrypted_session(user_id).get("tokens").get("access_token")) + else: + return {"user_id not found in session store"} + diff --git a/packages/auth0-ai/auth0_ai/utils/__init__.py b/packages/auth0-ai/auth0_ai/utils/__init__.py new file mode 100644 index 0000000..0a63561 --- /dev/null +++ b/packages/auth0-ai/auth0_ai/utils/__init__.py @@ -0,0 +1,7 @@ +""" +Auth0 AI Utilities Module +Provides utility functions and helpers for URL building and other common operations. +""" +from .url_builder import URLBuilder + +__all__ = ["URLBuilder"] diff --git a/packages/auth0-ai/auth0_ai/utils/url_builder.py b/packages/auth0-ai/auth0_ai/utils/url_builder.py new file mode 100644 index 0000000..0bf7a54 --- /dev/null +++ b/packages/auth0-ai/auth0_ai/utils/url_builder.py @@ -0,0 +1,117 @@ +from __future__ import annotations +from typing import Any, Dict +import json +from urllib.parse import urlencode +from auth0.authentication.pushed_authorization_requests import PushedAuthorizationRequests + +class URLBuilder: + """ + Handles construction of Auth0 authorization URLs and PAR requests. + Maintains original URL building logic from auth_client.py + """ + def __init__(self, auth_client: Any): + """ + Initialize URL builder. + Args: + auth_client: Parent AIAuth instance + """ + self.auth_client = auth_client + + def get_authorize_url( + self, + state: str, + connection: str | None = None, + scope: str | None = None, + additional_scopes: str | None = None, + **kwargs + ) -> str: + """ + Generate authorization URL. + Args: + state: State parameter for CSRF protection + connection: Auth0 connection to use + scope: OAuth scope + additional_scopes: Additional connection-specific scopes + **kwargs: Additional parameters to include in the URL + Returns: + Complete authorization URL + """ + # Base URL parameters + params = { + "response_type": "code", + "client_id": self.auth_client.client_id, + "redirect_uri": self.auth_client.redirect_uri, + "grant_type": "authorization_code", + "state": state + } + # Add optional parameters + if connection is not None: + params["connection"] = connection + if scope is not None: + params["scope"] = scope + if additional_scopes is not None: + params["connection_scope"] = additional_scopes + # Add any additional custom arguments + params.update(kwargs) + # Construct URL + query_string = urlencode(params) + return f"https://{self.auth_client.domain}/authorize?{query_string}" + + def get_authorize_par_url(self, state: str, request_uri: str) -> str: + """ + Generate PAR authorization URL. + Args: + state: State parameter for CSRF protection + request_uri: PAR request URI + Returns: + Complete PAR authorization URL + """ + params = { + "client_id": self.auth_client.client_id, + "state": state, + "request_uri": request_uri + } + query_string = urlencode(params) + return f"https://{self.auth_client.domain}/authorize?{query_string}" + async def create_par_request( + self, + state: str, + connection: str, + id_token: str, + scope: str | None = None, + **kwargs + ) -> Dict[str, Any]: + """ + Create a pushed authorization request. + Args: + state: State parameter for CSRF protection + connection: Connection to link + id_token: ID token for the primary user + scope: OAuth scope + **kwargs: Additional parameters + Returns: + PAR response containing request_uri + """ + par_client = PushedAuthorizationRequests( + self.auth_client.domain, + self.auth_client.client_id, + self.auth_client.client_secret + ) + # Prepare authorization details for account linking + auth_details = [{ + "type": "link_account", + "requested_connection": connection + }] + # Build PAR request + par_request = { + "response_type": "code", + "nonce": kwargs.get("nonce", "mynonce"), + "redirect_uri": self.auth_client.redirect_uri, + "audience": kwargs.get("audience", "my-account"), + "state": state, + "authorization_details": json.dumps(auth_details), + "scope": scope or "openid profile", + "id_token_hint": id_token, + **kwargs + } + return await par_client.pushed_authorization_request(**par_request) \ No newline at end of file From d7961d000521c666bc3469f73c58e4e3c63e87e4 Mon Sep 17 00:00:00 2001 From: mustafadeel Date: Thu, 20 Feb 2025 00:15:02 -0500 Subject: [PATCH 11/11] added unlink, https support and fixed some other things --- packages/auth0-ai/README.md | 2 +- .../auth0-ai/auth0_ai/auth/auth_client.py | 111 ++++++++++++------ packages/auth0-ai/auth0_ai/auth/user.py | 21 +++- .../auth0-ai/auth0_ai/server/auth_server.py | 62 ++++++++-- .../auth0_ai/session_module/manager.py | 40 +++---- .../auth0-ai/auth0_ai/token_module/manager.py | 6 +- 6 files changed, 163 insertions(+), 79 deletions(-) diff --git a/packages/auth0-ai/README.md b/packages/auth0-ai/README.md index 4d221c4..dfe3b5f 100644 --- a/packages/auth0-ai/README.md +++ b/packages/auth0-ai/README.md @@ -5,7 +5,7 @@ This package provides base methods to use Auth0 with your AI use cases. ## Installation ```bash -pip install git+https://github.com/mustafadeel/auth0-ai-python.git@main#subdirectory=packages/auth0-ai +pip install git+https://github.com/mustafadeel/auth0-ai-python.git@mod_struct#subdirectory=packages/auth0-ai ``` ## Running Tests diff --git a/packages/auth0-ai/auth0_ai/auth/auth_client.py b/packages/auth0-ai/auth0_ai/auth/auth_client.py index e02d377..d13d4d5 100644 --- a/packages/auth0-ai/auth0_ai/auth/auth_client.py +++ b/packages/auth0-ai/auth0_ai/auth/auth_client.py @@ -10,12 +10,12 @@ from .base import BaseAuth from .user import User -from server.auth_server import AuthServer -from token_module.manager import TokenManager -from session_module.manager import SessionManager -from state.login_state import LoginState -from state.link_state import LinkState -from utils.url_builder import URLBuilder +from auth0_ai.server.auth_server import AuthServer +from auth0_ai.token_module.manager import TokenManager +from auth0_ai.session_module.manager import SessionManager +from auth0_ai.state.login_state import LoginState +from auth0_ai.state.link_state import LinkState +from auth0_ai.utils.url_builder import URLBuilder class AIAuth(BaseAuth): """Main authentication class that orchestrates the auth flow""" @@ -112,39 +112,24 @@ async def link( state = self._generate_state() link_state = LinkState(self.state_store, state) link_state.set_user(primary_user_id) + + auth_url = self.url_builder.get_authorize_url( + state = state, + scope = "link_account", + audience = "my-account", + requested_connection = connection, + requested_connection_scope = scope, + id_token_hint = id_token, + client_id = self.client_id, + redirect_uri = self.redirect_uri, + ) + try: - # Create PAR request - par_client = PushedAuthorizationRequests( - self.domain, - self.client_id, - self.client_secret - ) - par_response = await par_client.pushed_authorization_request( - response_type="code", - nonce=kwargs.get("nonce", "mynonce"), - redirect_uri=self.redirect_uri, - audience=kwargs.get("audience", "my-account"), - state=state, - authorization_details=json.dumps([ - {"type": "link_account", "requested_connection": connection} - ]), - scope=scope or "openid profile", - id_token_hint=id_token, - **kwargs - ) - request_uri = par_response.get('request_uri') - if not request_uri: - raise ValueError("Failed to get request_uri from PAR response") - # Generate authorization URL with PAR - auth_url = self.url_builder.get_authorize_par_url( - state=state, - request_uri=request_uri - ) - # Open browser for linking try: webbrowser.open(auth_url) except webbrowser.Error: print(f"Please navigate here: {auth_url}") + # Wait for linking completion user_id = await link_state.wait_for_completion() return { @@ -157,14 +142,62 @@ async def link( "is_successful": False, "user_id": primary_user_id, } - - def get_session_details(self, user_id: str) -> Dict[str, Any]: - """Get session details for a user""" - return self.session_manager.get_session_details(user_id) + async def unlink( + self, + primary_user_id: str, + connection: str, + id_token: str, + **kwargs + ) -> Dict[str, Any]: + """ + Handle account linking flow. + Args: + primary_user_id: ID of the user initiating the link + connection: Connection to link + id_token: ID token of the primary user + scope: OAuth scope + **kwargs: Additional parameters for authorization + Returns: + Dict containing link status and user information + """ + state = self._generate_state() + link_state = LinkState(self.state_store, state) + link_state.set_user(primary_user_id) + + auth_url = self.url_builder.get_authorize_url( + state = state, + scope = "unlink_account", + audience = "my-account", + requested_connection = connection, + id_token_hint = id_token, + client_id = self.client_id, + redirect_uri = self.redirect_uri, + ) + + try: + try: + webbrowser.open(auth_url) + except webbrowser.Error: + print(f"Please navigate here: {auth_url}") + + # Wait for linking completion + user_id = await link_state.wait_for_completion() + return { + "is_successful": bool(user_id), + "user_id": user_id or primary_user_id + } + except Exception as error: + print(f"Error during unlinking: {error}") + return { + "is_successful": False, + "user_id": primary_user_id, + } + def get_session(self, user: User) -> Dict[str, Any]: """Get session for a user object""" - return self.session_manager.get_session(user) + return self.session_manager.get_session(user = user) + def get_upstream_token( self, diff --git a/packages/auth0-ai/auth0_ai/auth/user.py b/packages/auth0-ai/auth0_ai/auth/user.py index 429449a..880a013 100644 --- a/packages/auth0-ai/auth0_ai/auth/user.py +++ b/packages/auth0-ai/auth0_ai/auth/user.py @@ -35,7 +35,22 @@ async def link(self, connection: str, scope: str | None = None, **kwargs) -> Dic scope=scope, **kwargs ) - + async def unlink(self, connection: str, scope: str | None = None, **kwargs) -> Dict[str, Any]: + """ + UnLink an existing authentication provider to this user account. + Args: + connection: The name of the connection to unlink (e.g., 'github', 'google') + **kwargs: Additional parameters for the unlinking process + Returns: + Dict containing unlink status and user information + """ + return await self._auth_client.unlink( + primary_user_id=self.user_id, + connection=connection, + id_token=self.get_id_token(), + scope=scope, + **kwargs + ) def get_id_token(self) -> str: """Get the user's ID token""" return self._auth_client.token_manager.get_id_token(self.user_id) @@ -59,14 +74,14 @@ def get_3rd_party_token(self, connection: str) -> Dict[str, Any]: refresh_token = self.get_refresh_token() return self._auth_client.get_upstream_token(connection, refresh_token) - async def get_profile(self) -> Dict[str, Any]: + def get_profile(self) -> Dict[str, Any]: """ Get the user's profile information. Returns: Dict containing user profile data """ access_token = self.get_access_token() - return await self._auth_client.token_manager.get_userinfo(access_token) + return self._auth_client.token_manager.get_userinfo(access_token) async def get_token_info(self) -> Dict[str, Any]: """ diff --git a/packages/auth0-ai/auth0_ai/server/auth_server.py b/packages/auth0-ai/auth0_ai/server/auth_server.py index 15e8d39..a371042 100644 --- a/packages/auth0-ai/auth0_ai/server/auth_server.py +++ b/packages/auth0-ai/auth0_ai/server/auth_server.py @@ -4,6 +4,7 @@ from typing import Any import uvicorn +import os from fastapi import FastAPI from .routes import setup_routes @@ -27,21 +28,56 @@ def __init__(self, auth_client: Any): parsed_uri = urllib.parse.urlparse(auth_client.redirect_uri) self.host = parsed_uri.hostname self.port = parsed_uri.port + self.protocol = urllib.parse.urlparse(auth_client.redirect_uri).scheme # Setup routes with dependencies setup_routes(self.app, auth_client) self.start() - def start(self) -> None: - """Start the FastAPI server in a daemon thread.""" - server_thread = threading.Thread( - target=uvicorn.run, - args=(self.app,), - kwargs={ - "host": self.host or "localhost", - "port": self.port or "3000", - "log_level": "info" - }, - daemon=False - ) - server_thread.start() + def _is_valid_file(self,file_path) -> bool: + """Check if the file exists and is accessible.""" + valid = False + try: + valid = os.path.isfile(file_path) and os.access(file_path, os.R_OK) + return valid + except Exception as e: + return valid + + + def start(self): + """Runs FastAPI as the middleware inside a separate thread.""" + if (self.protocol == "https"): + + ssl_keyfile = os.getenv("AUTH0_SSL_KEYFILE") + ssl_certfile = os.getenv("AUTH0_SSL_CERTFILE") + + if not self._is_valid_file(ssl_keyfile) or not self._is_valid_file(ssl_certfile): + raise ValueError( + "AUTH0_SSL_KEYFILE and AUTH0_SSL_CERTFILE environment variables must be set with valid file paths for HTTPS.") + + server_thread = threading.Thread( + target=uvicorn.run, + args=(self.app,), + kwargs={ + "host": self.host, + "port": self.port, + "ssl_keyfile": ssl_keyfile, # Path to private key + "ssl_certfile": ssl_certfile, # Path to certificate + "log_level": "error"}, + daemon=True # Daemon mode so it exits when the main thread exits + ) + else: + server_thread = threading.Thread( + target=uvicorn.run, + args=(self.app,), + kwargs={ + "host": self.host, + "port": self.port, + "log_level": "info"}, + daemon=True # Daemon mode so it exits when the main thread exits + ) + try: + server_thread.start() + except Exception as e: + print(f"Error starting middleware server: {str(e)}") + raise e \ No newline at end of file diff --git a/packages/auth0-ai/auth0_ai/session_module/manager.py b/packages/auth0-ai/auth0_ai/session_module/manager.py index 7318417..2b2833e 100644 --- a/packages/auth0-ai/auth0_ai/session_module/manager.py +++ b/packages/auth0-ai/auth0_ai/session_module/manager.py @@ -16,10 +16,10 @@ def __init__( self, auth_client: Any, use_local_cache: bool = True, - get_sessions=None, - get_session=None, - set_session=None, - delete_session=None, + get_ext_sessions=None, + get_ext_session=None, + set_ext_session=None, + delete_ext_session=None, store: Optional[BaseStore] = None ): """ @@ -28,10 +28,10 @@ def __init__( Args: auth_client: Parent AIAuth instance use_local_cache: Whether to use local cache (default: True) - get_sessions: Optional custom get_sessions function - get_session: Optional custom get_session function - set_session: Optional custom set_session function - delete_session: Optional custom delete_session function + get_ext_sessions: Optional custom get_sessions function + get_ext_session: Optional custom get_session function + set_ext_session: Optional custom set_session function + delete_ext_session: Optional custom delete_session function store: Optional custom store implementation """ self.auth_client = auth_client @@ -39,35 +39,35 @@ def __init__( self.secret_key = auth_client.secret_key # Custom function handlers - self.get_sessions = get_sessions - self.get_session = get_session - self.set_session = set_session - self.delete_session = delete_session + self.get_ext_sessions = get_ext_sessions + self.get_ext_session = get_ext_session + self.set_ext_session = set_ext_session + self.delete_ext_session = delete_ext_session # Original interface methods with exact same names and signatures def _get_stored_sessions(self) -> Any: """Get all stored session IDs""" - if hasattr(self, 'get_sessions') and self.get_sessions: - return self.get_sessions() + if hasattr(self, 'get_ext_sessions') and self.get_ext_sessions: + return self.get_ext_sessions() return self.store.get_stored_sessions() def _get_stored_session(self, user_id: str) -> str: """Get a specific stored session""" - if hasattr(self, 'get_session') and self.get_session: - return self.get_session() + if hasattr(self, 'get_ext_session') and self.get_ext_session: + return self.get_ext_session() return self.store.get_stored_session(user_id) def _set_stored_session(self, user_id: str, encrypted_session_data: str) -> None: """Store a session""" - if hasattr(self, 'set_session') and self.set_session: - self.set_session() + if hasattr(self, 'set_ext_session') and self.set_ext_session: + self.set_ext_session() else: self.store.set_stored_session(user_id, encrypted_session_data) def _delete_stored_session(self, user_id: str) -> None: """Delete a stored session""" - if hasattr(self, 'delete_session') and self.delete_session: - self.delete_session() + if hasattr(self, 'delete_ext_session') and self.delete_ext_session: + self.delete_ext_session() else: self.store.delete_stored_session(user_id) diff --git a/packages/auth0-ai/auth0_ai/token_module/manager.py b/packages/auth0-ai/auth0_ai/token_module/manager.py index 100eab3..63c7a8e 100644 --- a/packages/auth0-ai/auth0_ai/token_module/manager.py +++ b/packages/auth0-ai/auth0_ai/token_module/manager.py @@ -122,7 +122,7 @@ def get_upstream_token( grant_type="urn:auth0:params:oauth:grant-type:token-exchange:federated-connection-access-token" ) - async def get_userinfo(self, access_token: str) -> Dict[str, Any]: + def get_userinfo(self, access_token: str) -> Dict[str, Any]: """ Get user information using access token. Args: @@ -130,7 +130,7 @@ async def get_userinfo(self, access_token: str) -> Dict[str, Any]: Returns: User profile information """ - return await self.auth_client.get( + return self.auth_client.get( url=f"https://{self.auth_client.domain}/userinfo", headers={"Authorization": f"Bearer {access_token}"} ) @@ -165,7 +165,7 @@ def validate_tokens(self, token_data: Dict[str, Any]) -> bool: # Session Token Methods (used in User.py) def get_id_token(self, user_id: str) -> Dict[str, Any]: if user_id in self.auth_client.session_manager._get_stored_sessions(): - return (self.auth_client.session_manager.get_encrypted_session(user_id).get("tokens").get("access_token")) + return (self.auth_client.session_manager.get_encrypted_session(user_id).get("tokens").get("id_token")) else: return {"user_id not found in session store"}