From 3f0b6118898bb4282bf29a748415c53e294f8629 Mon Sep 17 00:00:00 2001
From: Jay Prakash <0freerunning@gmail.com>
Date: Thu, 3 Apr 2025 03:49:02 +0530
Subject: [PATCH 01/74] Add user verification and password reset
---
backend/add_users_to_db.py | 2 +
backend/app/auth/dependencies.py | 52 +++-
backend/app/auth/routers.py | 175 ++++++++++++-
backend/app/auth/schemas.py | 2 +
backend/app/auth/utils.py | 137 ++++++++++
backend/app/config.py | 7 +
backend/app/contextual_mab/routers.py | 10 +-
backend/app/database.py | 9 +
backend/app/email.py | 168 ++++++++++++
backend/app/mab/routers.py | 10 +-
backend/app/messages/routers.py | 10 +-
backend/app/users/models.py | 61 +++++
backend/app/users/routers.py | 35 ++-
backend/app/users/schemas.py | 45 +++-
.../97137b6afb58_added_user_fields.py | 45 ++++
backend/requirements.txt | 2 +
backend/tests/test.env | 8 +
frontend/package-lock.json | 128 +++++++++
frontend/src/app/forgot-password/page.tsx | 163 ++++++++++++
frontend/src/app/login/page.tsx | 42 +--
frontend/src/app/reset-password/page.tsx | 245 ++++++++++++++++++
.../src/app/verification-required/page.tsx | 171 ++++++++++++
frontend/src/app/verify/page.tsx | 159 ++++++++++++
.../src/components/ProtectedComponent.tsx | 14 +-
frontend/src/utils/api.ts | 43 +++
frontend/src/utils/auth.tsx | 114 +++++---
26 files changed, 1773 insertions(+), 84 deletions(-)
create mode 100644 backend/app/auth/utils.py
create mode 100644 backend/app/email.py
create mode 100644 backend/migrations/versions/97137b6afb58_added_user_fields.py
create mode 100644 frontend/src/app/forgot-password/page.tsx
create mode 100644 frontend/src/app/reset-password/page.tsx
create mode 100644 frontend/src/app/verification-required/page.tsx
create mode 100644 frontend/src/app/verify/page.tsx
diff --git a/backend/add_users_to_db.py b/backend/add_users_to_db.py
index 69e5f3f..0aa7183 100644
--- a/backend/add_users_to_db.py
+++ b/backend/add_users_to_db.py
@@ -36,6 +36,8 @@
api_daily_quota=ADMIN_API_DAILY_QUOTA,
created_datetime_utc=datetime.now(timezone.utc),
updated_datetime_utc=datetime.now(timezone.utc),
+ is_active=True,
+ is_verified=True,
)
diff --git a/backend/app/auth/dependencies.py b/backend/app/auth/dependencies.py
index c41aea0..1db3022 100644
--- a/backend/app/auth/dependencies.py
+++ b/backend/app/auth/dependencies.py
@@ -12,7 +12,7 @@
from jwt.exceptions import InvalidTokenError
from sqlalchemy.ext.asyncio import AsyncSession
-from ..config import CHECK_API_LIMIT, DEFAULT_API_QUOTA, DEFAULT_EXPERIMENTS_QUOTA
+from ..config import CHECK_API_LIMIT, DEFAULT_API_QUOTA, DEFAULT_EXPERIMENTS_QUOTA, ENV
from ..database import get_async_session
from ..users.models import (
UserDB,
@@ -20,6 +20,7 @@
get_user_by_api_key,
get_user_by_username,
save_user_to_db,
+ update_user_verification_status,
)
from ..users.schemas import UserCreate
from ..utils import (
@@ -54,6 +55,12 @@ async def authenticate_key(
token = credentials.credentials
try:
user_db = await get_user_by_api_key(token, asession)
+
+ if not user_db.is_active:
+ raise HTTPException(
+ status_code=403, detail="Account is inactive. Please contact support."
+ )
+
return user_db
except UserNotFoundError as e:
raise HTTPException(status_code=403, detail="Invalid API key") from e
@@ -67,12 +74,18 @@ async def authenticate_credentials(
"""
try:
user_db = await get_user_by_username(username, asession)
+
+ if not user_db.is_active:
+ logger.warning(f"Inactive user {username} attempted to login")
+ return None
+
if verify_password_salted_hash(password, user_db.hashed_password):
# hardcode "fullaccess" now, but may use it in the future
return AuthenticatedUser(
username=username,
access_level="fullaccess",
api_key_first_characters=user_db.api_key_first_characters,
+ is_verified=user_db.is_verified,
)
else:
return None
@@ -87,14 +100,21 @@ async def authenticate_or_create_google_user(
asession: AsyncSession,
) -> Optional[AuthenticatedUser]:
"""
- Check if user exists in Db. If not, create user
+ Check if user exists in Db. If not, create user.
+ Google authenticated users are automatically verified.
"""
try:
user_db = await get_user_by_username(google_email, asession)
+
+ if not user_db.is_verified:
+ asession.add(user_db)
+ await update_user_verification_status(user_db, True, asession)
+
return AuthenticatedUser(
username=user_db.username,
access_level="fullaccess",
api_key_first_characters=user_db.api_key_first_characters,
+ is_verified=user_db.is_verified,
)
except UserNotFoundError:
user = UserCreate(
@@ -103,7 +123,7 @@ async def authenticate_or_create_google_user(
api_daily_quota=DEFAULT_API_QUOTA,
)
api_key = generate_key()
- user_db = await save_user_to_db(user, api_key, asession)
+ user_db = await save_user_to_db(user, api_key, asession, is_verified=True)
await update_api_limits(
request.app.state.redis, user_db.username, user_db.api_daily_quota
)
@@ -111,6 +131,7 @@ async def authenticate_or_create_google_user(
username=user_db.username,
access_level="fullaccess",
api_key_first_characters=user_db.api_key_first_characters,
+ is_verified=True,
)
@@ -135,6 +156,13 @@ async def get_current_user(
# fetch user from database
try:
user_db = await get_user_by_username(username, asession)
+
+ if not user_db.is_active:
+ raise HTTPException(
+ status_code=403,
+ detail="Account is inactive. Please contact support.",
+ )
+
return user_db
except UserNotFoundError as err:
raise credentials_exception from err
@@ -142,6 +170,24 @@ async def get_current_user(
raise credentials_exception from err
+async def get_verified_user(
+ user_db: Annotated[UserDB, Depends(get_current_user)],
+) -> UserDB:
+ """
+ Check if the user is verified
+ """
+ if ENV == "testing":
+ return user_db
+
+ if not user_db.is_verified:
+ raise HTTPException(
+ status_code=403,
+ detail="Account not verified. Please check your email to verify "
+ "your account.",
+ )
+ return user_db
+
+
def create_access_token(username: str) -> str:
"""
Create an access token for the user
diff --git a/backend/app/auth/routers.py b/backend/app/auth/routers.py
index 6fc22d6..eec0c4a 100644
--- a/backend/app/auth/routers.py
+++ b/backend/app/auth/routers.py
@@ -1,11 +1,26 @@
-from fastapi import APIRouter, Depends, HTTPException
+from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from fastapi.requests import Request
from fastapi.security import OAuth2PasswordRequestForm
-from google.auth.transport import requests
+from google.auth.transport import requests as google_requests
from google.oauth2 import id_token
+from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import AsyncSession
-from ..database import get_async_session
+from ..database import get_async_session, get_redis
+from ..email import EmailService
+from ..users.models import (
+ UserNotFoundError,
+ get_user_by_username,
+ update_user_password,
+ update_user_verification_status,
+)
+from ..users.schemas import (
+ EmailVerificationRequest,
+ MessageResponse,
+ PasswordResetConfirm,
+ PasswordResetRequest,
+)
+from ..utils import setup_logger
from .config import NEXT_PUBLIC_GOOGLE_LOGIN_CLIENT_ID
from .dependencies import (
authenticate_credentials,
@@ -13,6 +28,11 @@
create_access_token,
)
from .schemas import AuthenticationDetails, GoogleLoginData
+from .utils import (
+ generate_password_reset_token,
+ generate_verification_token,
+ verify_token,
+)
TAG_METADATA = {
"name": "Authentication",
@@ -21,6 +41,9 @@
router = APIRouter(tags=[TAG_METADATA["name"]])
+email_service = EmailService()
+logger = setup_logger()
+
@router.post("/login")
async def login(
@@ -47,6 +70,7 @@ async def login(
token_type="bearer",
access_level=user.access_level,
username=user.username,
+ is_verified=user.is_verified,
)
@@ -58,13 +82,13 @@ async def login_google(
) -> AuthenticationDetails:
"""
Verify google token, check if user exists. If user does not exist, create user
- Return JWT token for user
+ Return JWT token for user. Google users are automatically verified.
"""
try:
idinfo = id_token.verify_oauth2_token(
login_data.credential,
- requests.Request(),
+ google_requests.Request(),
NEXT_PUBLIC_GOOGLE_LOGIN_CLIENT_ID,
)
if idinfo["iss"] not in ["accounts.google.com", "https://accounts.google.com"]:
@@ -87,4 +111,145 @@ async def login_google(
token_type="bearer",
access_level=user.access_level,
username=user.username,
+ is_verified=user.is_verified,
)
+
+
+@router.post("/request-password-reset", response_model=MessageResponse)
+async def request_password_reset(
+ reset_request: PasswordResetRequest,
+ background_tasks: BackgroundTasks,
+ asession: AsyncSession = Depends(get_async_session),
+ redis: Redis = Depends(get_redis),
+) -> MessageResponse:
+ """
+ Request a password reset email
+ """
+ response_msg = (
+ "If an account with this email exists, a password reset link has been sent."
+ )
+
+ try:
+ logger.info(f"Generated password reset token for user {reset_request.username}")
+ user = await get_user_by_username(reset_request.username, asession)
+
+ logger.info(f"User found: {user.username}")
+ if not user:
+ return MessageResponse(message=response_msg)
+
+ token = await generate_password_reset_token(user.user_id, user.username, redis)
+ background_tasks.add_task(
+ email_service.send_password_reset_email, user.username, user.username, token
+ )
+
+ return MessageResponse(message=response_msg)
+ except UserNotFoundError:
+ logger.warning(f"User not found: {reset_request.username}")
+ return MessageResponse(message=response_msg)
+ except Exception as e:
+ logger.exception("An error occurred processing your request")
+ raise HTTPException(
+ status_code=500, detail="An error occurred processing your request"
+ ) from e
+
+
+@router.post("/reset-password", response_model=MessageResponse)
+async def reset_password(
+ reset_data: PasswordResetConfirm,
+ asession: AsyncSession = Depends(get_async_session),
+ redis: Redis = Depends(get_redis),
+) -> MessageResponse:
+ """
+ Reset a user's password with the provided token
+ """
+ is_valid, payload = await verify_token(reset_data.token, "password_reset", redis)
+
+ if not is_valid or "username" not in payload:
+ raise HTTPException(
+ status_code=400, detail="Invalid or expired password reset token"
+ )
+
+ try:
+ user = await get_user_by_username(payload["username"], asession)
+ asession.add(user)
+
+ await update_user_password(user, reset_data.new_password, asession)
+
+ return MessageResponse(message="Your password has been reset successfully")
+ except UserNotFoundError as e:
+ raise HTTPException(status_code=404, detail="User not found") from e
+ except Exception as e:
+ raise HTTPException(
+ status_code=500, detail="An error occurred while resetting your password"
+ ) from e
+
+
+@router.post("/verify-email", response_model=MessageResponse)
+async def verify_email(
+ verification_data: EmailVerificationRequest,
+ asession: AsyncSession = Depends(get_async_session),
+ redis: Redis = Depends(get_redis),
+) -> MessageResponse:
+ """
+ Verify a user's email with the provided token
+ """
+ is_valid, payload = await verify_token(
+ verification_data.token, "verification", redis
+ )
+
+ if not is_valid or "username" not in payload:
+ raise HTTPException(
+ status_code=400, detail="Invalid or expired verification token"
+ )
+
+ try:
+ user = await get_user_by_username(payload["username"], asession)
+
+ asession.add(user)
+ await update_user_verification_status(user, True, asession)
+
+ return MessageResponse(message="Your email has been verified successfully")
+ except UserNotFoundError as e:
+ raise HTTPException(status_code=404, detail="User not found") from e
+ except Exception as e:
+ raise HTTPException(
+ status_code=500, detail="An error occurred while verifying your email"
+ ) from e
+
+
+@router.post("/resend-verification", response_model=MessageResponse)
+async def resend_verification(
+ reset_request: PasswordResetRequest,
+ background_tasks: BackgroundTasks,
+ asession: AsyncSession = Depends(get_async_session),
+ redis: Redis = Depends(get_redis),
+) -> MessageResponse:
+ """
+ Resend verification email
+ """
+ response_msg = (
+ "If an account with this email exists, a password reset link has been sent."
+ )
+
+ try:
+ user = await get_user_by_username(reset_request.username, asession)
+
+ if not user:
+ return MessageResponse(message=response_msg)
+
+ if user.is_verified:
+ return MessageResponse(message="Your account is already verified")
+
+ token = await generate_verification_token(user.user_id, user.username, redis)
+
+ background_tasks.add_task(
+ email_service.send_verification_email, user.username, user.username, token
+ )
+
+ return MessageResponse(message=response_msg)
+ except UserNotFoundError:
+ return MessageResponse(message=response_msg)
+ except Exception as e:
+ raise HTTPException(
+ status_code=500, detail="An error occurred processing your request"
+ ) from e
diff --git a/backend/app/auth/schemas.py b/backend/app/auth/schemas.py
index c8da6d7..9b4b853 100644
--- a/backend/app/auth/schemas.py
+++ b/backend/app/auth/schemas.py
@@ -14,6 +14,7 @@ class AuthenticatedUser(BaseModel):
username: str
access_level: AccessLevel
api_key_first_characters: str
+ is_verified: bool
model_config = ConfigDict(from_attributes=True)
@@ -38,5 +39,6 @@ class AuthenticationDetails(BaseModel):
access_level: AccessLevel
api_key_first_characters: str
username: str
+ is_verified: bool
model_config = ConfigDict(from_attributes=True)
diff --git a/backend/app/auth/utils.py b/backend/app/auth/utils.py
new file mode 100644
index 0000000..d6b2862
--- /dev/null
+++ b/backend/app/auth/utils.py
@@ -0,0 +1,137 @@
+import secrets
+from datetime import datetime, timedelta, timezone
+from typing import Dict, Tuple
+
+import jwt
+from redis.asyncio import Redis
+
+from ..utils import setup_logger
+from .config import JWT_ALGORITHM, JWT_SECRET
+
+# Token settings
+VERIFICATION_TOKEN_EXPIRE_MINUTES = 1440 # 24 hours
+PASSWORD_RESET_TOKEN_EXPIRE_MINUTES = 30 # 30 minutes
+
+logger = setup_logger()
+
+
+async def generate_verification_token(user_id: int, username: str, redis: Redis) -> str:
+ """
+ Generates a verification token for account activation
+
+ Args:
+ user_id: The user's ID
+ username: The user's username (email)
+ redis: Redis connection
+
+ Returns:
+ JWT token for email verification
+ """
+ # Generate JWT token
+ token_jti = secrets.token_hex(16) # Add unique ID to prevent token reuse
+ payload = {
+ "sub": str(user_id),
+ "username": username,
+ "type": "verification",
+ "exp": datetime.now(timezone.utc)
+ + timedelta(minutes=VERIFICATION_TOKEN_EXPIRE_MINUTES),
+ "iat": datetime.now(timezone.utc),
+ "jti": token_jti,
+ }
+
+ token = jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
+
+ # Store token in Redis with expiry for additional security
+ await redis.set(
+ f"verification_token:{token_jti}",
+ str(user_id),
+ ex=VERIFICATION_TOKEN_EXPIRE_MINUTES * 60,
+ )
+
+ logger.info(f"Generated verification token for user {user_id}")
+ return token
+
+
+async def generate_password_reset_token(
+ user_id: int, username: str, redis: Redis
+) -> str:
+ """
+ Generates a token for password reset
+
+ Args:
+ user_id: The user's ID
+ username: The user's username (email)
+ redis: Redis connection
+
+ Returns:
+ JWT token for password reset
+ """
+ # Generate JWT token
+ token_jti = secrets.token_hex(16) # Add unique ID to prevent token reuse
+ payload = {
+ "sub": str(user_id),
+ "username": username,
+ "type": "password_reset",
+ "exp": datetime.now(timezone.utc)
+ + timedelta(minutes=PASSWORD_RESET_TOKEN_EXPIRE_MINUTES),
+ "iat": datetime.now(timezone.utc),
+ "jti": token_jti,
+ }
+
+ token = jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
+
+ # Store token in Redis with expiry for additional security
+ await redis.set(
+ f"password_reset_token:{token_jti}",
+ str(user_id),
+ ex=PASSWORD_RESET_TOKEN_EXPIRE_MINUTES * 60,
+ )
+
+ logger.info(f"Generated password reset token for user {user_id}")
+ return token
+
+
+async def verify_token(token: str, token_type: str, redis: Redis) -> Tuple[bool, Dict]:
+ """
+ Verifies a token and returns user information if valid
+
+ Args:
+ token: The JWT token
+ token_type: Either "verification" or "password_reset"
+ redis: Redis connection
+
+ Returns:
+ Tuple of (is_valid, payload)
+ """
+ try:
+ # Decode the token
+ payload = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
+
+ # Check token type
+ if payload.get("type") != token_type:
+ logger.warning(
+ f"Invalid token type: expected {token_type}, got {payload.get('type')}"
+ )
+ return False, {}
+
+ # Check if token is in Redis (hasn't been used)
+ token_key = f"{token_type}_token:{payload['jti']}"
+ stored_user_id = await redis.get(token_key)
+
+ if not stored_user_id or stored_user_id.decode() != payload["sub"]:
+ logger.warning(
+ "Token validation failed: token not found in Redis or user_id mismatch"
+ )
+ return False, {}
+
+ # If verification successful, invalidate token to prevent reuse
+ await redis.delete(token_key)
+
+ logger.info(
+ f"Successfully verified {token_type} token for user {payload['sub']}"
+ )
+ return True, payload
+
+ except jwt.PyJWTError as e:
+ logger.error(f"JWT token verification error: {str(e)}")
+ return False, {}
diff --git a/backend/app/config.py b/backend/app/config.py
index 89de874..55ec33e 100644
--- a/backend/app/config.py
+++ b/backend/app/config.py
@@ -1,5 +1,6 @@
import os
+ENV = os.environ.get("ENV", "development")
POSTGRES_USER = os.environ.get("POSTGRES_USER", "postgres")
POSTGRES_PASSWORD = os.environ.get("POSTGRES_PASSWORD", "postgres")
POSTGRES_HOST = os.environ.get("POSTGRES_HOST", "localhost")
@@ -17,3 +18,9 @@
DEFAULT_API_QUOTA = int(os.environ.get("DEFAULT_API_QUOTA", 100))
CHECK_API_LIMIT = os.environ.get("CHECK_API_LIMIT", True)
CHECK_EXPERIMENTS_LIMIT = os.environ.get("CHECK_EXPERIMENTS_LIMIT", True)
+
+SES_REGION = os.environ.get("SES_REGION", None)
+SES_SENDER_EMAIL = os.environ.get("SES_SENDER_EMAIL", None)
+FRONTEND_URL = os.environ.get("FRONTEND_URL", None)
+AWS_ACCESS_KEY_ID = os.environ.get("AWS_ACCESS_KEY_ID", None)
+AWS_SECRET_ACCESS_KEY = os.environ.get("AWS_SECRET_ACCESS_KEY", None)
diff --git a/backend/app/contextual_mab/routers.py b/backend/app/contextual_mab/routers.py
index dd56b94..95a6524 100644
--- a/backend/app/contextual_mab/routers.py
+++ b/backend/app/contextual_mab/routers.py
@@ -4,7 +4,7 @@
from fastapi.exceptions import HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
-from ..auth.dependencies import authenticate_key, get_current_user
+from ..auth.dependencies import authenticate_key, get_verified_user
from ..database import get_async_session
from ..models import get_notifications_from_db, save_notifications_to_db
from ..schemas import ContextType, NotificationsResponse, Outcome
@@ -35,7 +35,7 @@
@router.post("/", response_model=ContextualBanditResponse)
async def create_contextual_mabs(
experiment: ContextualBandit,
- user_db: Annotated[UserDB, Depends(get_current_user)],
+ user_db: Annotated[UserDB, Depends(get_verified_user)],
asession: AsyncSession = Depends(get_async_session),
) -> ContextualBanditResponse | HTTPException:
"""
@@ -55,7 +55,7 @@ async def create_contextual_mabs(
@router.get("/", response_model=list[ContextualBanditResponse])
async def get_contextual_mabs(
- user_db: Annotated[UserDB, Depends(get_current_user)],
+ user_db: Annotated[UserDB, Depends(get_verified_user)],
asession: AsyncSession = Depends(get_async_session),
) -> list[ContextualBanditResponse]:
"""
@@ -89,7 +89,7 @@ async def get_contextual_mabs(
@router.get("/{experiment_id}", response_model=ContextualBanditResponse)
async def get_contextual_mab(
experiment_id: int,
- user_db: Annotated[UserDB, Depends(get_current_user)],
+ user_db: Annotated[UserDB, Depends(get_verified_user)],
asession: AsyncSession = Depends(get_async_session),
) -> ContextualBanditResponse | HTTPException:
"""
@@ -116,7 +116,7 @@ async def get_contextual_mab(
@router.delete("/{experiment_id}", response_model=dict)
async def delete_contextual_mab(
experiment_id: int,
- user_db: Annotated[UserDB, Depends(get_current_user)],
+ user_db: Annotated[UserDB, Depends(get_verified_user)],
asession: AsyncSession = Depends(get_async_session),
) -> dict:
"""
diff --git a/backend/app/database.py b/backend/app/database.py
index e944dd2..7265e9e 100644
--- a/backend/app/database.py
+++ b/backend/app/database.py
@@ -2,6 +2,8 @@
from collections.abc import AsyncGenerator, Generator
from typing import ContextManager
+from fastapi import Request
+from redis.asyncio import Redis
from sqlalchemy.engine import URL, Engine, create_engine
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
from sqlalchemy.orm import Session
@@ -81,3 +83,10 @@ async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
get_sqlalchemy_async_engine(), expire_on_commit=False
) as async_session:
yield async_session
+
+
+async def get_redis(request: Request) -> Redis:
+ """
+ Returns the Redis client for the request.
+ """
+ return request.app.state.redis
diff --git a/backend/app/email.py b/backend/app/email.py
new file mode 100644
index 0000000..be59269
--- /dev/null
+++ b/backend/app/email.py
@@ -0,0 +1,168 @@
+from typing import Any, Dict, Optional
+
+import boto3
+from botocore.exceptions import ClientError
+from fastapi import HTTPException
+
+from .config import (
+ AWS_ACCESS_KEY_ID,
+ AWS_SECRET_ACCESS_KEY,
+ ENV,
+ FRONTEND_URL,
+ SES_REGION,
+ SES_SENDER_EMAIL,
+)
+from .utils import setup_logger
+
+logger = setup_logger()
+
+
+class EmailService:
+ """Service to send emails via AWS SES"""
+
+ def __init__(
+ self,
+ aws_region: Optional[str] = None,
+ sender_email: Optional[str] = None,
+ aws_access_key_id: Optional[str] = None,
+ aws_secret_access_key: Optional[str] = None,
+ ) -> None:
+ """Initialize the email service with AWS credentials"""
+ self.aws_region = aws_region or SES_REGION
+ self.sender_email = sender_email or SES_SENDER_EMAIL
+
+ session_kwargs: Dict[str, Any] = {"region_name": self.aws_region}
+
+ aws_access_key = aws_access_key_id or AWS_ACCESS_KEY_ID
+ aws_secret_key = aws_secret_access_key or AWS_SECRET_ACCESS_KEY
+
+ if aws_access_key and aws_secret_key:
+ session_kwargs.update(
+ {
+ "aws_access_key_id": aws_access_key,
+ "aws_secret_access_key": aws_secret_key,
+ }
+ )
+
+ self.client = boto3.client("ses", **session_kwargs)
+
+ async def send_verification_email(
+ self, email: str, username: str, token: str
+ ) -> Dict[str, Any]:
+ """
+ Send account verification email
+ """
+ verification_url = f"{FRONTEND_URL}/verify?token={token}"
+
+ subject = "Verify Your Account"
+ html_body = f"""
+
+
+
+ Account Verification
+ Hello {username},
+ Thank you for signing up. Please click the link below to verify
+ your account:
+ Verify My Account
+ This link will expire in 24 hours.
+ If you did not create this account, please ignore this email.
+
+
+ """
+ text_body = f"""
+ Account Verification
+
+ Hello {username},
+
+ Thank you for signing up. Please use the link below to verify your account:
+
+ {verification_url}
+
+ This link will expire in 24 hours.
+
+ If you did not create this account, please ignore this email.
+ """
+
+ return await self._send_email(email, subject, html_body, text_body)
+
+ async def send_password_reset_email(
+ self, email: str, username: str, token: str
+ ) -> Dict[str, Any]:
+ """
+ Send password reset email
+ """
+ reset_url = f"{FRONTEND_URL}/reset-password?token={token}"
+
+ subject = "Password Reset Request"
+ html_body = f"""
+
+
+
+ Password Reset
+ Hello {username},
+ We received a request to reset your password. Please click the link below
+ to set a new password:
+ Reset My Password
+ This link will expire in 30 minutes.
+ If you did not request a password reset, please ignore this email.
+
+
+ """
+ text_body = f"""
+ Password Reset
+
+ Hello {username},
+
+ We received a request to reset your password. Please use the link below
+ to set a new password:
+
+ {reset_url}
+
+ This link will expire in 30 minutes.
+
+ If you did not request a password reset, please ignore this email.
+ """
+
+ return await self._send_email(email, subject, html_body, text_body)
+
+ async def _send_email(
+ self, recipient: str, subject: str, html_body: str, text_body: str
+ ) -> Dict[str, Any]:
+ """
+ Send an email using AWS SES
+ """
+ if ENV == "testing" or self.client is None:
+ logger.info(f"[MOCK EMAIL] To: {recipient}, Subject: {subject}")
+ logger.info(f"[MOCK EMAIL] Text: {text_body[:100]}...")
+ return {"MessageId": "mock-message-id"}
+
+ try:
+ response = self.client.send_email(
+ Source=self.sender_email,
+ Destination={
+ "ToAddresses": [recipient],
+ },
+ Message={
+ "Subject": {
+ "Data": subject,
+ },
+ "Body": {
+ "Text": {
+ "Data": text_body,
+ },
+ "Html": {
+ "Data": html_body,
+ },
+ },
+ },
+ )
+ logger.info(
+ f"Email sent to {recipient}! Message ID: {response.get('MessageId')}"
+ )
+ return response
+ except ClientError as e:
+ error_message = f"Error sending email to {recipient}: {e}"
+ logger.error(error_message)
+ raise HTTPException(
+ status_code=500, detail=f"Failed to send email: {str(e)}"
+ ) from e
diff --git a/backend/app/mab/routers.py b/backend/app/mab/routers.py
index 75b2cc0..86cd179 100644
--- a/backend/app/mab/routers.py
+++ b/backend/app/mab/routers.py
@@ -4,7 +4,7 @@
from fastapi.exceptions import HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
-from ..auth.dependencies import authenticate_key, get_current_user
+from ..auth.dependencies import authenticate_key, get_verified_user
from ..database import get_async_session
from ..models import get_notifications_from_db, save_notifications_to_db
from ..schemas import NotificationsResponse, Outcome, RewardLikelihood
@@ -33,7 +33,7 @@
@router.post("/", response_model=MultiArmedBanditResponse)
async def create_mab(
experiment: MultiArmedBandit,
- user_db: Annotated[UserDB, Depends(get_current_user)],
+ user_db: Annotated[UserDB, Depends(get_verified_user)],
asession: AsyncSession = Depends(get_async_session),
) -> MultiArmedBanditResponse:
"""
@@ -55,7 +55,7 @@ async def create_mab(
@router.get("/", response_model=list[MultiArmedBanditResponse])
async def get_mabs(
- user_db: Annotated[UserDB, Depends(get_current_user)],
+ user_db: Annotated[UserDB, Depends(get_verified_user)],
asession: AsyncSession = Depends(get_async_session),
) -> list[MultiArmedBanditResponse]:
"""
@@ -88,7 +88,7 @@ async def get_mabs(
@router.get("/{experiment_id}", response_model=MultiArmedBanditResponse)
async def get_mab(
experiment_id: int,
- user_db: Annotated[UserDB, Depends(get_current_user)],
+ user_db: Annotated[UserDB, Depends(get_verified_user)],
asession: AsyncSession = Depends(get_async_session),
) -> MultiArmedBanditResponse:
"""
@@ -115,7 +115,7 @@ async def get_mab(
@router.delete("/{experiment_id}", response_model=dict)
async def delete_mab(
experiment_id: int,
- user_db: Annotated[UserDB, Depends(get_current_user)],
+ user_db: Annotated[UserDB, Depends(get_verified_user)],
asession: AsyncSession = Depends(get_async_session),
) -> dict:
"""
diff --git a/backend/app/messages/routers.py b/backend/app/messages/routers.py
index 31eff9c..fea7bbe 100644
--- a/backend/app/messages/routers.py
+++ b/backend/app/messages/routers.py
@@ -3,7 +3,7 @@
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
-from ..auth.dependencies import get_current_user
+from ..auth.dependencies import get_verified_user
from ..database import get_async_session
from ..users.models import UserDB
from .models import EventMessageDB, MessageDB
@@ -14,7 +14,7 @@
@router.get("/", response_model=list[MessageResponse])
async def get_messages(
- user_db: Annotated[UserDB, Depends(get_current_user)],
+ user_db: Annotated[UserDB, Depends(get_verified_user)],
asession: AsyncSession = Depends(get_async_session),
) -> list[MessageResponse]:
"""
@@ -27,7 +27,7 @@ async def get_messages(
@router.post("/", response_model=MessageResponse)
async def create_message(
message: EventMessageCreate,
- user_db: Annotated[UserDB, Depends(get_current_user)],
+ user_db: Annotated[UserDB, Depends(get_verified_user)],
asession: AsyncSession = Depends(get_async_session),
) -> MessageResponse:
"""
@@ -47,7 +47,7 @@ async def create_message(
@router.delete("/", response_model=list[MessageResponse])
async def delete_messages(
message_ids: list[int],
- user_db: Annotated[UserDB, Depends(get_current_user)],
+ user_db: Annotated[UserDB, Depends(get_verified_user)],
asession: AsyncSession = Depends(get_async_session),
) -> list[MessageResponse]:
"""
@@ -63,7 +63,7 @@ async def delete_messages(
@router.patch("/", response_model=list[MessageResponse])
async def mark_messages_as_read(
message_read_toggle: MessageReadToggle,
- user_db: Annotated[UserDB, Depends(get_current_user)],
+ user_db: Annotated[UserDB, Depends(get_verified_user)],
asession: AsyncSession = Depends(get_async_session),
) -> list[MessageResponse]:
"""
diff --git a/backend/app/users/models.py b/backend/app/users/models.py
index 5f1b6d1..6500c3a 100644
--- a/backend/app/users/models.py
+++ b/backend/app/users/models.py
@@ -1,6 +1,7 @@
from datetime import datetime, timezone
from sqlalchemy import (
+ Boolean,
DateTime,
Integer,
String,
@@ -48,6 +49,11 @@ class UserDB(Base):
updated_datetime_utc: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False
)
+ access_level: Mapped[str] = mapped_column(
+ String, nullable=False, default="fullaccess"
+ )
+ is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
+ is_verified: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
def __repr__(self) -> str:
"""Pretty Print"""
@@ -58,6 +64,7 @@ async def save_user_to_db(
user: UserCreateWithPassword | UserCreate,
api_key: str,
asession: AsyncSession,
+ is_verified: bool = False,
) -> UserDB:
"""
Saves a user in the database
@@ -90,6 +97,9 @@ async def save_user_to_db(
api_key_first_characters=api_key[:5],
created_datetime_utc=datetime.now(timezone.utc),
updated_datetime_utc=datetime.now(timezone.utc),
+ is_active=True,
+ is_verified=is_verified,
+ access_level="fullaccess",
)
asession.add(user_db)
@@ -119,6 +129,57 @@ async def update_user_api_key(
return user_db
+async def update_user_verification_status(
+ user_db: UserDB,
+ is_verified: bool,
+ asession: AsyncSession,
+) -> UserDB:
+ """
+ Updates a user's verification status
+ """
+ user_db.is_verified = is_verified
+ user_db.updated_datetime_utc = datetime.now(timezone.utc)
+
+ await asession.commit()
+ await asession.refresh(user_db)
+
+ return user_db
+
+
+async def update_user_active_status(
+ user_db: UserDB,
+ is_active: bool,
+ asession: AsyncSession,
+) -> UserDB:
+ """
+ Updates a user's active status
+ """
+ user_db.is_active = is_active
+ user_db.updated_datetime_utc = datetime.now(timezone.utc)
+
+ await asession.commit()
+ await asession.refresh(user_db)
+
+ return user_db
+
+
+async def update_user_password(
+ user_db: UserDB,
+ new_password: str,
+ asession: AsyncSession,
+) -> UserDB:
+ """
+ Updates a user's password
+ """
+ user_db.hashed_password = get_password_salted_hash(new_password)
+ user_db.updated_datetime_utc = datetime.now(timezone.utc)
+
+ await asession.commit()
+ await asession.refresh(user_db)
+
+ return user_db
+
+
async def get_user_by_username(
username: str,
asession: AsyncSession,
diff --git a/backend/app/users/routers.py b/backend/app/users/routers.py
index 435109a..4bf80dd 100644
--- a/backend/app/users/routers.py
+++ b/backend/app/users/routers.py
@@ -1,13 +1,15 @@
from typing import Annotated
-from fastapi import APIRouter, Depends
-from fastapi.exceptions import HTTPException
+from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from fastapi.requests import Request
+from redis.asyncio import Redis
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
-from ..auth.dependencies import get_current_user
-from ..database import get_async_session
+from ..auth.dependencies import get_current_user, get_verified_user
+from ..auth.utils import generate_verification_token
+from ..database import get_async_session, get_redis
+from ..email import EmailService
from ..users.models import (
UserAlreadyExistsError,
UserDB,
@@ -17,21 +19,25 @@
from ..utils import generate_key, setup_logger, update_api_limits
from .schemas import KeyResponse, UserCreate, UserCreateWithPassword, UserRetrieve
+# Router setup
TAG_METADATA = {
"name": "Admin",
"description": "_Requires user login._ Only administrator user has access to these "
"endpoints.",
}
-router = APIRouter(prefix="/user", tags=["Admin"])
+router = APIRouter(prefix="/user", tags=["Users"])
logger = setup_logger()
+email_service = EmailService()
@router.post("/", response_model=UserCreate)
async def create_user(
user: UserCreateWithPassword,
request: Request,
+ background_tasks: BackgroundTasks,
asession: AsyncSession = Depends(get_async_session),
+ redis: Redis = Depends(get_redis),
) -> UserCreate | None:
"""
Create user endpoint.
@@ -43,9 +49,19 @@ async def create_user(
user=user,
api_key=new_api_key,
asession=asession,
+ is_verified=False,
)
- await update_api_limits(
- request.app.state.redis, user_new.username, user_new.api_daily_quota
+ await update_api_limits(redis, user_new.username, user_new.api_daily_quota)
+
+ token = await generate_verification_token(
+ user_new.user_id, user_new.username, redis
+ )
+
+ background_tasks.add_task(
+ email_service.send_verification_email,
+ user_new.username,
+ user_new.username,
+ token,
)
return UserCreate(
@@ -76,12 +92,15 @@ async def get_user(
api_key_updated_datetime_utc=user_db.api_key_updated_datetime_utc,
created_datetime_utc=user_db.created_datetime_utc,
updated_datetime_utc=user_db.updated_datetime_utc,
+ is_active=user_db.is_active,
+ is_verified=user_db.is_verified,
+ access_level=user_db.access_level,
)
@router.put("/rotate-key", response_model=KeyResponse)
async def get_new_api_key(
- user_db: Annotated[UserDB, Depends(get_current_user)],
+ user_db: Annotated[UserDB, Depends(get_verified_user)],
asession: AsyncSession = Depends(get_async_session),
) -> KeyResponse | None:
"""
diff --git a/backend/app/users/schemas.py b/backend/app/users/schemas.py
index d773de3..5bb3294 100644
--- a/backend/app/users/schemas.py
+++ b/backend/app/users/schemas.py
@@ -1,10 +1,9 @@
from datetime import datetime
from typing import Optional
-from pydantic import BaseModel, ConfigDict
+from pydantic import BaseModel, ConfigDict, EmailStr
-# not yet used.
class UserCreate(BaseModel):
"""
Pydantic model for user creation
@@ -19,7 +18,7 @@ class UserCreate(BaseModel):
class UserCreateWithPassword(UserCreate):
"""
- Pydantic model for user creation
+ Pydantic model for user creation with password.
"""
password: str
@@ -39,6 +38,9 @@ class UserRetrieve(BaseModel):
api_key_updated_datetime_utc: datetime
created_datetime_utc: datetime
updated_datetime_utc: datetime
+ is_active: bool
+ is_verified: bool
+ access_level: str
model_config = ConfigDict(from_attributes=True)
@@ -51,3 +53,40 @@ class KeyResponse(BaseModel):
username: str
new_api_key: str
model_config = ConfigDict(from_attributes=True)
+
+
+class PasswordResetRequest(BaseModel):
+ """
+ Pydantic model for password reset request
+ """
+
+ username: EmailStr
+ model_config = ConfigDict(from_attributes=True)
+
+
+class PasswordResetConfirm(BaseModel):
+ """
+ Pydantic model for password reset confirmation
+ """
+
+ token: str
+ new_password: str
+ model_config = ConfigDict(from_attributes=True)
+
+
+class EmailVerificationRequest(BaseModel):
+ """
+ Pydantic model for email verification
+ """
+
+ token: str
+ model_config = ConfigDict(from_attributes=True)
+
+
+class MessageResponse(BaseModel):
+ """
+ Pydantic model for generic message responses
+ """
+
+ message: str
+ model_config = ConfigDict(from_attributes=True)
diff --git a/backend/migrations/versions/97137b6afb58_added_user_fields.py b/backend/migrations/versions/97137b6afb58_added_user_fields.py
new file mode 100644
index 0000000..5743a04
--- /dev/null
+++ b/backend/migrations/versions/97137b6afb58_added_user_fields.py
@@ -0,0 +1,45 @@
+"""Added user fields
+
+Revision ID: 97137b6afb58
+Revises: d95b5c0590c3
+Create Date: 2025-04-02 21:09:56.417771
+
+"""
+
+from typing import Sequence, Union
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision: str = "97137b6afb58"
+down_revision: Union[str, None] = "d95b5c0590c3"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.add_column(
+ "users",
+ sa.Column(
+ "access_level", sa.String(), nullable=False, server_default="fullaccess"
+ ),
+ )
+ op.add_column(
+ "users",
+ sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
+ )
+ op.add_column(
+ "users",
+ sa.Column("is_verified", sa.Boolean(), nullable=False, server_default="false"),
+ )
+ # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_column("users", "is_verified")
+ op.drop_column("users", "is_active")
+ op.drop_column("users", "access_level")
+ # ### end Alembic commands ###
diff --git a/backend/requirements.txt b/backend/requirements.txt
index 27b026d..18633d3 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -13,3 +13,5 @@ scikit-learn==1.6.1
scipy==1.15.2
sqlalchemy[asyncio]==2.0.20
uvicorn==0.23.2
+boto3==1.37.25
+pydantic[email]
diff --git a/backend/tests/test.env b/backend/tests/test.env
index 648493b..a222901 100644
--- a/backend/tests/test.env
+++ b/backend/tests/test.env
@@ -1,3 +1,4 @@
+ENV=testing
PROMETHEUS_MULTIPROC_DIR=/tmp
# DB connection
POSTGRES_USER=postgres-test-user
@@ -10,3 +11,10 @@ REDIS_HOST=redis://localhost:6381
ADMIN_USERNAME=test@idinsight.org
ADMIN_PASSWORD=test123
ADMIN_API_KEY=testkey123
+
+NEXT_PUBLIC_GOOGLE_LOGIN_CLIENT_ID=random_key
+SES_REGION=ap-south-1
+SES_SENDER_EMAIL=no-reply@example.com
+FRONTEND_URL=http://localhost:3000
+AWS_ACCESS_KEY_ID=aws_access_key
+AWS_SECRET_ACCESS_KEY=aws_secret_key
diff --git a/frontend/package-lock.json b/frontend/package-lock.json
index c0e85ec..2d97cb8 100644
--- a/frontend/package-lock.json
+++ b/frontend/package-lock.json
@@ -408,6 +408,134 @@
"node": ">= 10"
}
},
+ "node_modules/@next/swc-darwin-x64": {
+ "version": "14.2.25",
+ "resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-14.2.25.tgz",
+ "integrity": "sha512-V+iYM/QR+aYeJl3/FWWU/7Ix4b07ovsQ5IbkwgUK29pTHmq+5UxeDr7/dphvtXEq5pLB/PucfcBNh9KZ8vWbug==",
+ "cpu": [
+ "x64"
+ ],
+ "license": "MIT",
+ "optional": true,
+ "os": [
+ "darwin"
+ ],
+ "engines": {
+ "node": ">= 10"
+ }
+ },
+ "node_modules/@next/swc-linux-arm64-gnu": {
+ "version": "14.2.25",
+ "resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-14.2.25.tgz",
+ "integrity": "sha512-LFnV2899PJZAIEHQ4IMmZIgL0FBieh5keMnriMY1cK7ompR+JUd24xeTtKkcaw8QmxmEdhoE5Mu9dPSuDBgtTg==",
+ "cpu": [
+ "arm64"
+ ],
+ "license": "MIT",
+ "optional": true,
+ "os": [
+ "linux"
+ ],
+ "engines": {
+ "node": ">= 10"
+ }
+ },
+ "node_modules/@next/swc-linux-arm64-musl": {
+ "version": "14.2.25",
+ "resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-14.2.25.tgz",
+ "integrity": "sha512-QC5y5PPTmtqFExcKWKYgUNkHeHE/z3lUsu83di488nyP0ZzQ3Yse2G6TCxz6nNsQwgAx1BehAJTZez+UQxzLfw==",
+ "cpu": [
+ "arm64"
+ ],
+ "license": "MIT",
+ "optional": true,
+ "os": [
+ "linux"
+ ],
+ "engines": {
+ "node": ">= 10"
+ }
+ },
+ "node_modules/@next/swc-linux-x64-gnu": {
+ "version": "14.2.25",
+ "resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-14.2.25.tgz",
+ "integrity": "sha512-y6/ML4b9eQ2D/56wqatTJN5/JR8/xdObU2Fb1RBidnrr450HLCKr6IJZbPqbv7NXmje61UyxjF5kvSajvjye5w==",
+ "cpu": [
+ "x64"
+ ],
+ "license": "MIT",
+ "optional": true,
+ "os": [
+ "linux"
+ ],
+ "engines": {
+ "node": ">= 10"
+ }
+ },
+ "node_modules/@next/swc-linux-x64-musl": {
+ "version": "14.2.25",
+ "resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-14.2.25.tgz",
+ "integrity": "sha512-sPX0TSXHGUOZFvv96GoBXpB3w4emMqKeMgemrSxI7A6l55VBJp/RKYLwZIB9JxSqYPApqiREaIIap+wWq0RU8w==",
+ "cpu": [
+ "x64"
+ ],
+ "license": "MIT",
+ "optional": true,
+ "os": [
+ "linux"
+ ],
+ "engines": {
+ "node": ">= 10"
+ }
+ },
+ "node_modules/@next/swc-win32-arm64-msvc": {
+ "version": "14.2.25",
+ "resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-14.2.25.tgz",
+ "integrity": "sha512-ReO9S5hkA1DU2cFCsGoOEp7WJkhFzNbU/3VUF6XxNGUCQChyug6hZdYL/istQgfT/GWE6PNIg9cm784OI4ddxQ==",
+ "cpu": [
+ "arm64"
+ ],
+ "license": "MIT",
+ "optional": true,
+ "os": [
+ "win32"
+ ],
+ "engines": {
+ "node": ">= 10"
+ }
+ },
+ "node_modules/@next/swc-win32-ia32-msvc": {
+ "version": "14.2.25",
+ "resolved": "https://registry.npmjs.org/@next/swc-win32-ia32-msvc/-/swc-win32-ia32-msvc-14.2.25.tgz",
+ "integrity": "sha512-DZ/gc0o9neuCDyD5IumyTGHVun2dCox5TfPQI/BJTYwpSNYM3CZDI4i6TOdjeq1JMo+Ug4kPSMuZdwsycwFbAw==",
+ "cpu": [
+ "ia32"
+ ],
+ "license": "MIT",
+ "optional": true,
+ "os": [
+ "win32"
+ ],
+ "engines": {
+ "node": ">= 10"
+ }
+ },
+ "node_modules/@next/swc-win32-x64-msvc": {
+ "version": "14.2.25",
+ "resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-14.2.25.tgz",
+ "integrity": "sha512-KSznmS6eFjQ9RJ1nEc66kJvtGIL1iZMYmGEXsZPh2YtnLtqrgdVvKXJY2ScjjoFnG6nGLyPFR0UiEvDwVah4Tw==",
+ "cpu": [
+ "x64"
+ ],
+ "license": "MIT",
+ "optional": true,
+ "os": [
+ "win32"
+ ],
+ "engines": {
+ "node": ">= 10"
+ }
+ },
"node_modules/@nodelib/fs.scandir": {
"version": "2.1.5",
"resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz",
diff --git a/frontend/src/app/forgot-password/page.tsx b/frontend/src/app/forgot-password/page.tsx
new file mode 100644
index 0000000..3f048cf
--- /dev/null
+++ b/frontend/src/app/forgot-password/page.tsx
@@ -0,0 +1,163 @@
+"use client";
+
+import { motion } from "framer-motion";
+import Link from "next/link";
+import { zodResolver } from "@hookform/resolvers/zod";
+import { useForm } from "react-hook-form";
+import * as z from "zod";
+import { Button } from "@/components/ui/button";
+import {
+ Card,
+ CardContent,
+ CardDescription,
+ CardFooter,
+ CardHeader,
+ CardTitle,
+} from "@/components/ui/card";
+import {
+ Form,
+ FormControl,
+ FormField,
+ FormItem,
+ FormLabel,
+ FormMessage,
+} from "@/components/ui/form";
+import { Input } from "@/components/ui/input";
+import { useState } from "react";
+import { apiCalls } from "@/utils/api";
+
+const formSchema = z.object({
+ email: z.string().email({
+ message: "Please enter a valid email address.",
+ }),
+});
+
+export default function ForgotPasswordPage() {
+ const form = useForm>({
+ resolver: zodResolver(formSchema),
+ defaultValues: {
+ email: "",
+ },
+ });
+
+ const [isSubmitting, setIsSubmitting] = useState(false);
+ const [success, setSuccess] = useState(false);
+ const [errorState, setErrorState] = useState(null);
+
+ async function onSubmit(values: z.infer) {
+ setIsSubmitting(true);
+ setErrorState(null);
+
+ try {
+ const response = await apiCalls.requestPasswordReset(values.email);
+ console.log("Password reset request response:", response);
+ setSuccess(true);
+ } catch (error) {
+ setErrorState("An error occurred while sending the reset link. Please try again later.");
+ console.error("Password reset request error:", error);
+ } finally {
+ setIsSubmitting(false);
+ }
+ }
+
+ return (
+
+
+
+
+
+ Forgot Password
+
+
+ Enter your email to receive a password reset link
+
+
+
+ {success ? (
+
+
+
+
+
+
+ If an account with this email exists, a password reset link has been sent.
+
+
+
+
+
+ Return to Login
+
+
+ ) : (
+
+
+ )}
+
+
+
+ {"Remember your password? "}
+
+ Sign in
+
+
+
+
+
+
+ );
+}
diff --git a/frontend/src/app/login/page.tsx b/frontend/src/app/login/page.tsx
index bb7c1df..893eb28 100644
--- a/frontend/src/app/login/page.tsx
+++ b/frontend/src/app/login/page.tsx
@@ -118,23 +118,31 @@ export default function LoginPage() {
)}
/>
- (
-
-
-
-
-
- Remember me
-
-
- )}
- />
+
+
(
+
+
+
+
+
+ Remember me
+
+
+ )}
+ />
+
+ Forgot password?
+
+
Sign in
diff --git a/frontend/src/app/reset-password/page.tsx b/frontend/src/app/reset-password/page.tsx
new file mode 100644
index 0000000..c589ac1
--- /dev/null
+++ b/frontend/src/app/reset-password/page.tsx
@@ -0,0 +1,245 @@
+"use client";
+
+import { motion } from "framer-motion";
+import Link from "next/link";
+import { zodResolver } from "@hookform/resolvers/zod";
+import { useForm } from "react-hook-form";
+import * as z from "zod";
+import { Button } from "@/components/ui/button";
+import {
+ Card,
+ CardContent,
+ CardDescription,
+ CardFooter,
+ CardHeader,
+ CardTitle,
+} from "@/components/ui/card";
+import {
+ Form,
+ FormControl,
+ FormField,
+ FormItem,
+ FormLabel,
+ FormMessage,
+} from "@/components/ui/form";
+import { Input } from "@/components/ui/input";
+import { useEffect, useState } from "react";
+import { apiCalls } from "@/utils/api";
+import { useRouter, useSearchParams } from "next/navigation";
+
+const formSchema = z
+ .object({
+ password: z.string().min(4, {
+ message: "Password must be at least 4 characters long.",
+ }),
+ confirm_password: z.string().min(4, {
+ message: "Password must be at least 4 characters long.",
+ }),
+ })
+ .refine((data) => data.password === data.confirm_password, {
+ message: "Passwords do not match",
+ path: ["confirm_password"],
+ });
+
+export default function ResetPasswordPage() {
+ const router = useRouter();
+ const searchParams = useSearchParams();
+
+ const [token, setToken] = useState(null);
+ const [isSubmitting, setIsSubmitting] = useState(false);
+ const [success, setSuccess] = useState(false);
+ const [tokenError, setTokenError] = useState(false);
+ const [errorState, setErrorState] = useState(null);
+
+ const form = useForm>({
+ resolver: zodResolver(formSchema),
+ defaultValues: {
+ password: "",
+ confirm_password: "",
+ },
+ });
+
+ useEffect(() => {
+ const tokenParam = searchParams?.get("token");
+ if (!tokenParam) {
+ setTokenError(true);
+ } else {
+ setToken(tokenParam);
+ }
+ }, [searchParams]);
+
+ async function onSubmit(values: z.infer) {
+ if (!token) {
+ setTokenError(true);
+ return;
+ }
+
+ setIsSubmitting(true);
+ setErrorState(null);
+
+ try {
+ const response = await apiCalls.resetPassword(token, values.password);
+ setSuccess(true);
+
+ // Redirect to login page after 3 seconds
+ setTimeout(() => {
+ router.push("/login");
+ }, 3000);
+ } catch (error) {
+ setErrorState("Failed to reset password. The link may have expired or is invalid.");
+ console.error("Password reset error:", error);
+ } finally {
+ setIsSubmitting(false);
+ }
+ }
+
+ return (
+
+
+
+
+
+ Reset Password
+
+
+ Enter your new password
+
+
+
+ {tokenError ? (
+
+
+
+
+
+
+ Missing or invalid reset token. Please request a new password reset link.
+
+
+
+
+
+ Request New Reset Link
+
+
+ ) : success ? (
+
+
+
+
+
+
+ Your password has been reset successfully. Redirecting to login...
+
+
+
+
+
+ ) : (
+
+
+ )}
+
+
+
+ {"Remember your password? "}
+
+ Sign in
+
+
+
+
+
+
+ );
+}
diff --git a/frontend/src/app/verification-required/page.tsx b/frontend/src/app/verification-required/page.tsx
new file mode 100644
index 0000000..b7b0955
--- /dev/null
+++ b/frontend/src/app/verification-required/page.tsx
@@ -0,0 +1,171 @@
+"use client";
+
+import { motion } from "framer-motion";
+import Link from "next/link";
+import { Button } from "@/components/ui/button";
+import {
+ Card,
+ CardContent,
+ CardDescription,
+ CardFooter,
+ CardHeader,
+ CardTitle,
+} from "@/components/ui/card";
+import { useState } from "react";
+import { apiCalls } from "@/utils/api";
+import { useAuth } from "@/utils/auth";
+
+export default function VerificationRequiredPage() {
+ const { user, logout } = useAuth();
+ const [isResending, setIsResending] = useState(false);
+ const [resendSuccess, setResendSuccess] = useState(false);
+ const [resendError, setResendError] = useState(null);
+
+ const handleResendVerification = async () => {
+ if (!user) {
+ setResendError("User information not available. Please try logging in again.");
+ return;
+ }
+
+ setIsResending(true);
+ setResendError(null);
+
+ try {
+ await apiCalls.resendVerification(user);
+ setResendSuccess(true);
+ } catch (error) {
+ setResendError("Failed to resend verification email. Please try again later.");
+ console.error("Resend verification error:", error);
+ } finally {
+ setIsResending(false);
+ }
+ };
+
+ return (
+
+
+
+
+
+ Email Verification Required
+
+
+ Please verify your email address to continue
+
+
+
+
+
+
+
+
+ Your account requires email verification. We've sent a verification link to your email address.
+
+
+
+
+
+ {resendSuccess && (
+
+
+
+
+
+ Verification email has been resent. Please check your inbox.
+
+
+
+
+ )}
+
+ {resendError && (
+
+ )}
+
+
+
+ {isResending
+ ? "Sending..."
+ : resendSuccess
+ ? "Email Sent"
+ : "Resend Verification Email"}
+
+
+ Logout
+
+
+
+
+
+ Already verified? Try{" "}
+
+ logging in again
+
+
+
+
+
+
+ );
+}
diff --git a/frontend/src/app/verify/page.tsx b/frontend/src/app/verify/page.tsx
new file mode 100644
index 0000000..2273ba1
--- /dev/null
+++ b/frontend/src/app/verify/page.tsx
@@ -0,0 +1,159 @@
+"use client";
+
+import { motion } from "framer-motion";
+import Link from "next/link";
+import { Button } from "@/components/ui/button";
+import {
+ Card,
+ CardContent,
+ CardDescription,
+ CardFooter,
+ CardHeader,
+ CardTitle,
+} from "@/components/ui/card";
+import { useEffect, useState } from "react";
+import { apiCalls } from "@/utils/api";
+import { useRouter, useSearchParams } from "next/navigation";
+
+export default function VerifyEmailPage() {
+ const router = useRouter();
+ const searchParams = useSearchParams();
+
+ const [token, setToken] = useState(null);
+ const [status, setStatus] = useState<'loading' | 'success' | 'error'>('loading');
+ const [errorMessage, setErrorMessage] = useState('');
+
+ useEffect(() => {
+ const tokenParam = searchParams?.get("token");
+ if (!tokenParam) {
+ setStatus('error');
+ setErrorMessage('Missing verification token.');
+ return;
+ }
+
+ setToken(tokenParam);
+ verifyEmail(tokenParam);
+ }, [searchParams]);
+
+ const verifyEmail = async (token: string) => {
+ try {
+ const response = await apiCalls.verifyEmail(token);
+ setStatus('success');
+
+ // Redirect to login page after 3 seconds
+ setTimeout(() => {
+ router.push("/login");
+ }, 3000);
+ } catch (error) {
+ setStatus('error');
+ setErrorMessage('Failed to verify email. The link may have expired or is invalid.');
+ console.error("Email verification error:", error);
+ }
+ };
+
+ const handleResendVerification = () => {
+ router.push("/resend-verification");
+ };
+
+ return (
+
+
+
+
+
+ Email Verification
+
+
+ {status === 'loading' ? 'Verifying your email address...' :
+ status === 'success' ? 'Email verified successfully!' :
+ 'Email verification failed'}
+
+
+
+ {status === 'loading' && (
+
+ )}
+
+ {status === 'success' && (
+
+
+
+
+
+
+ Your email has been verified successfully! You can now log in.
+
+
+
+
+
Redirecting to login page...
+
+ )}
+
+ {status === 'error' && (
+
+
+
+ Request New Verification Link
+
+
+ )}
+
+
+
+ {"Return to "}
+
+ Sign in
+
+
+
+
+
+
+ );
+}
diff --git a/frontend/src/components/ProtectedComponent.tsx b/frontend/src/components/ProtectedComponent.tsx
index 9cc3cd9..dc7716e 100644
--- a/frontend/src/components/ProtectedComponent.tsx
+++ b/frontend/src/components/ProtectedComponent.tsx
@@ -5,27 +5,35 @@ import { usePathname, useRouter } from "next/navigation";
interface ProtectedComponentProps {
children: React.ReactNode;
+ requireVerified?: boolean;
}
const ProtectedComponent: React.FC = ({
children,
+ requireVerified = true,
}) => {
const router = useRouter();
- const { token } = useAuth();
+ const { token, isVerified } = useAuth();
const pathname = usePathname();
const [isClient, setIsClient] = useState(false);
useEffect(() => {
if (!token) {
router.push("/login?sourcePage=" + encodeURIComponent(pathname));
+ return;
}
- }, [token, pathname, router]);
+
+ if (requireVerified && !isVerified) {
+ router.push("/verification-required");
+ }
+ }, [token, isVerified, requireVerified, pathname, router]);
// This is to prevent the page from starting to load the children before the token is checked
useEffect(() => {
setIsClient(true);
}, []);
- if (!token || !isClient) {
+
+ if (!token || (requireVerified && !isVerified) || !isClient) {
return null;
} else {
return <>{children}>;
diff --git a/frontend/src/utils/api.ts b/frontend/src/utils/api.ts
index 86f8726..9a60457 100644
--- a/frontend/src/utils/api.ts
+++ b/frontend/src/utils/api.ts
@@ -91,10 +91,53 @@ const registerUser = async (username: string, password: string) => {
}
};
+const requestPasswordReset = async (username: string) => {
+ try {
+ const response = await api.post("/request-password-reset", { username });
+ return response.data;
+ } catch (error) {
+ throw new Error("Error requesting password reset");
+ }
+};
+
+const resetPassword = async (token: string, newPassword: string) => {
+ try {
+ const response = await api.post("/reset-password", {
+ token,
+ new_password: newPassword
+ });
+ return response.data;
+ } catch (error) {
+ throw new Error("Error resetting password");
+ }
+};
+
+const verifyEmail = async (token: string) => {
+ try {
+ const response = await api.post("/verify-email", { token });
+ return response.data;
+ } catch (error) {
+ throw new Error("Error verifying email");
+ }
+};
+
+const resendVerification = async (username: string) => {
+ try {
+ const response = await api.post("/resend-verification", { username });
+ return response.data;
+ } catch (error) {
+ throw new Error("Error resending verification email");
+ }
+};
+
export const apiCalls = {
getUser,
getLoginToken,
getGoogleLoginToken,
registerUser,
+ requestPasswordReset,
+ resetPassword,
+ verifyEmail,
+ resendVerification,
};
export default api;
diff --git a/frontend/src/utils/auth.tsx b/frontend/src/utils/auth.tsx
index 015b122..805a3cb 100644
--- a/frontend/src/utils/auth.tsx
+++ b/frontend/src/utils/auth.tsx
@@ -1,11 +1,12 @@
"use client";
import { apiCalls } from "@/utils/api";
import { useRouter, useSearchParams } from "next/navigation";
-import { ReactNode, createContext, useContext, useState } from "react";
+import { ReactNode, createContext, useContext, useState, useEffect } from "react";
type AuthContextType = {
token: string | null;
user: string | null;
+ isVerified: boolean;
login: (username: string, password: string) => void;
logout: () => void;
loginError: string | null;
@@ -25,7 +26,6 @@ type AuthProviderProps = {
};
const AuthProvider = ({ children }: AuthProviderProps) => {
-
const getInitialToken = () => {
if (typeof window !== "undefined") {
return localStorage.getItem("ee-token");
@@ -42,25 +42,64 @@ const AuthProvider = ({ children }: AuthProviderProps) => {
const [user, setUser] = useState(getInitialUsername);
const [token, setToken] = useState(getInitialToken);
+ const [isVerified, setIsVerified] = useState(false);
const [loginError, setLoginError] = useState(null);
const searchParams = useSearchParams();
const router = useRouter();
+ // Check verification status on init if token exists
+ useEffect(() => {
+ const checkVerificationStatus = async () => {
+ const currentToken = getInitialToken();
+ if (currentToken) {
+ try {
+ const userData = await apiCalls.getUser(currentToken);
+ setIsVerified(userData.is_verified);
+ } catch (error) {
+ console.error("Error fetching user verification status:", error);
+ }
+ }
+ };
+
+ checkVerificationStatus();
+ }, []);
+
const login = async (username: string, password: string) => {
const sourcePage = searchParams.has("sourcePage")
? decodeURIComponent(searchParams.get("sourcePage") as string)
: "/";
try {
- const { access_token } = await apiCalls.getLoginToken(username, password);
+ const response = await apiCalls.getLoginToken(username, password);
+ const { access_token } = response;
+
localStorage.setItem("ee-token", access_token);
localStorage.setItem("ee-username", username);
setUser(username);
setToken(access_token);
setLoginError(null);
- router.push(sourcePage);
+
+ // Check if verification status is in the response
+ if (response.is_verified !== undefined) {
+ setIsVerified(response.is_verified);
+ } else {
+ // If not in response, fetch user data to get verification status
+ try {
+ const userData = await apiCalls.getUser(access_token);
+ setIsVerified(userData.is_verified);
+ } catch (error) {
+ console.error("Error fetching user verification status:", error);
+ }
+ }
+
+ // Redirect to verification page if not verified, otherwise to original destination
+ if (response.is_verified === false) {
+ router.push("/verification-required");
+ } else {
+ router.push(sourcePage);
+ }
} catch (error: unknown) {
if (
error &&
@@ -86,27 +125,27 @@ const AuthProvider = ({ children }: AuthProviderProps) => {
? decodeURIComponent(searchParams.get("sourcePage") as string)
: "/";
- apiCalls
- .getGoogleLoginToken({ client_id: client_id, credential: credential })
- .then(
- ({
- access_token,
- username,
- }: {
- access_token: string;
- username: string;
- }) => {
- localStorage.setItem("ee-token", access_token);
- localStorage.setItem("ee-username", username);
- setUser(username);
- setToken(access_token);
- router.push(sourcePage);
- }
- )
- .catch((error: Error) => {
- setLoginError("Invalid Google credentials");
- console.error("Google login error:", error);
+ try {
+ const response = await apiCalls.getGoogleLoginToken({
+ client_id: client_id,
+ credential: credential
});
+
+ const { access_token, username } = response;
+
+ localStorage.setItem("ee-token", access_token);
+ localStorage.setItem("ee-username", username);
+
+ setUser(username);
+ setToken(access_token);
+
+ setIsVerified(true);
+
+ router.push(sourcePage);
+ } catch (error) {
+ setLoginError("Invalid Google credentials");
+ console.error("Google login error:", error);
+ }
};
const logout = () => {
@@ -115,16 +154,18 @@ const AuthProvider = ({ children }: AuthProviderProps) => {
localStorage.removeItem("ee-username");
setUser(null);
setToken(null);
+ setIsVerified(false);
router.push("/login");
};
const authValue: AuthContextType = {
- token: token,
- user: user,
- login: login,
- loginError: loginError,
- loginGoogle: loginGoogle,
- logout: logout,
+ token,
+ user,
+ isVerified,
+ login,
+ loginError,
+ loginGoogle,
+ logout,
};
return (
@@ -141,3 +182,16 @@ export const useAuth = () => {
}
return context;
};
+
+export const useRequireVerified = () => {
+ const { token, isVerified } = useAuth();
+ const router = useRouter();
+
+ useEffect(() => {
+ if (token && !isVerified) {
+ router.push("/verification-required");
+ }
+ }, [token, isVerified, router]);
+
+ return { token, isVerified };
+};
From 3e862414bedd455cac2c8fba644451ea930df25b Mon Sep 17 00:00:00 2001
From: Jay Prakash <0freerunning@gmail.com>
Date: Thu, 3 Apr 2025 04:01:47 +0530
Subject: [PATCH 02/74] Added env variables for test
---
.github/workflows/unit_tests.yaml | 6 ++++++
1 file changed, 6 insertions(+)
diff --git a/.github/workflows/unit_tests.yaml b/.github/workflows/unit_tests.yaml
index 0c4ffd5..c9286b3 100644
--- a/.github/workflows/unit_tests.yaml
+++ b/.github/workflows/unit_tests.yaml
@@ -15,6 +15,12 @@ env:
ADMIN_USERNAME: test@idinsight.org
ADMIN_PASSWORD: test123
ADMIN_API_KEY: testkey123
+ SES_REGION: ap-south-1
+ SES_SENDER_EMAIL: no-reply@example.com
+ FRONTEND_URL: http://localhost:3000
+ AWS_ACCESS_KEY_ID: aws_access_key
+ AWS_SECRET_ACCESS_KEY: aws_secret_key
+ ENV: testing
jobs:
container-job:
runs-on: ubuntu-22.04
From b183861e5036c8bdb4195e30b5b43e06c280389a Mon Sep 17 00:00:00 2001
From: Jay Prakash <0freerunning@gmail.com>
Date: Fri, 4 Apr 2025 04:14:05 +0530
Subject: [PATCH 03/74] Fix redirect issue
---
.../src/components/ProtectedComponent.tsx | 29 ++++++++++---------
frontend/src/utils/auth.tsx | 21 +++++++++++---
2 files changed, 32 insertions(+), 18 deletions(-)
diff --git a/frontend/src/components/ProtectedComponent.tsx b/frontend/src/components/ProtectedComponent.tsx
index dc7716e..f15a761 100644
--- a/frontend/src/components/ProtectedComponent.tsx
+++ b/frontend/src/components/ProtectedComponent.tsx
@@ -13,27 +13,28 @@ const ProtectedComponent: React.FC = ({
requireVerified = true,
}) => {
const router = useRouter();
- const { token, isVerified } = useAuth();
+ const { token, isVerified, isLoading } = useAuth();
const pathname = usePathname();
const [isClient, setIsClient] = useState(false);
- useEffect(() => {
- if (!token) {
- router.push("/login?sourcePage=" + encodeURIComponent(pathname));
- return;
- }
-
- if (requireVerified && !isVerified) {
- router.push("/verification-required");
- }
- }, [token, isVerified, requireVerified, pathname, router]);
-
- // This is to prevent the page from starting to load the children before the token is checked
useEffect(() => {
setIsClient(true);
}, []);
- if (!token || (requireVerified && !isVerified) || !isClient) {
+ useEffect(() => {
+ if (isClient && !isLoading) {
+ if (!token) {
+ router.push("/login?sourcePage=" + encodeURIComponent(pathname));
+ return;
+ }
+
+ if (requireVerified && !isVerified) {
+ router.push("/verification-required");
+ }
+ }
+ }, [token, isVerified, isLoading, requireVerified, pathname, router, isClient]);
+
+ if (!isClient || isLoading || !token || (requireVerified && !isVerified)) {
return null;
} else {
return <>{children}>;
diff --git a/frontend/src/utils/auth.tsx b/frontend/src/utils/auth.tsx
index 805a3cb..ee07bad 100644
--- a/frontend/src/utils/auth.tsx
+++ b/frontend/src/utils/auth.tsx
@@ -7,6 +7,7 @@ type AuthContextType = {
token: string | null;
user: string | null;
isVerified: boolean;
+ isLoading: boolean;
login: (username: string, password: string) => void;
logout: () => void;
loginError: string | null;
@@ -43,6 +44,7 @@ const AuthProvider = ({ children }: AuthProviderProps) => {
const [user, setUser] = useState(getInitialUsername);
const [token, setToken] = useState(getInitialToken);
const [isVerified, setIsVerified] = useState(false);
+ const [isLoading, setIsLoading] = useState(!!getInitialToken());
const [loginError, setLoginError] = useState(null);
const searchParams = useSearchParams();
@@ -53,11 +55,15 @@ const AuthProvider = ({ children }: AuthProviderProps) => {
const checkVerificationStatus = async () => {
const currentToken = getInitialToken();
if (currentToken) {
+ setIsLoading(true);
try {
const userData = await apiCalls.getUser(currentToken);
setIsVerified(userData.is_verified);
} catch (error) {
console.error("Error fetching user verification status:", error);
+ logout();
+ } finally {
+ setIsLoading(false);
}
}
};
@@ -71,6 +77,7 @@ const AuthProvider = ({ children }: AuthProviderProps) => {
: "/";
try {
+ setIsLoading(true);
const response = await apiCalls.getLoginToken(username, password);
const { access_token } = response;
@@ -111,6 +118,8 @@ const AuthProvider = ({ children }: AuthProviderProps) => {
} else {
setLoginError("An unexpected error occurred. Please try again later.");
}
+ } finally {
+ setIsLoading(false);
}
};
@@ -126,6 +135,7 @@ const AuthProvider = ({ children }: AuthProviderProps) => {
: "/";
try {
+ setIsLoading(true);
const response = await apiCalls.getGoogleLoginToken({
client_id: client_id,
credential: credential
@@ -145,6 +155,8 @@ const AuthProvider = ({ children }: AuthProviderProps) => {
} catch (error) {
setLoginError("Invalid Google credentials");
console.error("Google login error:", error);
+ } finally {
+ setIsLoading(false);
}
};
@@ -162,6 +174,7 @@ const AuthProvider = ({ children }: AuthProviderProps) => {
token,
user,
isVerified,
+ isLoading,
login,
loginError,
loginGoogle,
@@ -184,14 +197,14 @@ export const useAuth = () => {
};
export const useRequireVerified = () => {
- const { token, isVerified } = useAuth();
+ const { token, isVerified, isLoading } = useAuth();
const router = useRouter();
useEffect(() => {
- if (token && !isVerified) {
+ if (!isLoading && token && !isVerified) {
router.push("/verification-required");
}
- }, [token, isVerified, router]);
+ }, [token, isVerified, isLoading, router]);
- return { token, isVerified };
+ return { token, isVerified, isLoading };
};
From c79830bd71950cef942f76874304d8d9bfabf31c Mon Sep 17 00:00:00 2001
From: Jay Prakash <0freerunning@gmail.com>
Date: Tue, 22 Apr 2025 03:04:43 +0530
Subject: [PATCH 04/74] Adding the workspace feature
---
backend/app/__init__.py | 2 +
backend/app/auth/dependencies.py | 7 +-
backend/app/auth/routers.py | 47 +-
backend/app/auth/utils.py | 8 -
backend/app/contextual_mab/models.py | 20 +-
backend/app/contextual_mab/routers.py | 82 ++-
backend/app/contextual_mab/schemas.py | 1 +
backend/app/email.py | 59 +++
backend/app/mab/models.py | 19 +-
backend/app/mab/routers.py | 98 +++-
backend/app/mab/schemas.py | 1 +
backend/app/models.py | 7 +-
backend/app/users/models.py | 3 +-
backend/app/users/routers.py | 14 +-
backend/app/users/schemas.py | 3 +
backend/app/workspaces/__init__.py | 0
backend/app/workspaces/models.py | 303 +++++++++++
backend/app/workspaces/routers.py | 479 ++++++++++++++++++
backend/app/workspaces/schemas.py | 111 ++++
backend/app/workspaces/utils.py | 141 ++++++
.../949c9fc0461d_workspace_relationship.py | 32 ++
.../versions/977e7e73ce06_workspace_model.py | 56 ++
.../app/(protected)/workspace/create/page.tsx | 131 +++++
.../app/(protected)/workspace/invite/page.tsx | 193 +++++++
.../src/app/(protected)/workspace/page.tsx | 116 +++++
.../src/app/(protected)/workspace/types.ts | 51 ++
frontend/src/components/WorkspaceSelector.tsx | 142 ++++++
frontend/src/components/sidebar.tsx | 23 +-
frontend/src/utils/api.ts | 90 +++-
frontend/src/utils/auth.tsx | 82 ++-
30 files changed, 2256 insertions(+), 65 deletions(-)
create mode 100644 backend/app/workspaces/__init__.py
create mode 100644 backend/app/workspaces/models.py
create mode 100644 backend/app/workspaces/routers.py
create mode 100644 backend/app/workspaces/schemas.py
create mode 100644 backend/app/workspaces/utils.py
create mode 100644 backend/migrations/versions/949c9fc0461d_workspace_relationship.py
create mode 100644 backend/migrations/versions/977e7e73ce06_workspace_model.py
create mode 100644 frontend/src/app/(protected)/workspace/create/page.tsx
create mode 100644 frontend/src/app/(protected)/workspace/invite/page.tsx
create mode 100644 frontend/src/app/(protected)/workspace/page.tsx
create mode 100644 frontend/src/app/(protected)/workspace/types.ts
create mode 100644 frontend/src/components/WorkspaceSelector.tsx
diff --git a/backend/app/__init__.py b/backend/app/__init__.py
index 845b9c2..2653972 100644
--- a/backend/app/__init__.py
+++ b/backend/app/__init__.py
@@ -10,6 +10,7 @@
from .users.routers import (
router as users_router,
) # to avoid circular imports
+from .workspaces.routers import router as workspaces_router
from .utils import setup_logger
logger = setup_logger()
@@ -40,6 +41,7 @@ def create_app() -> FastAPI:
app.include_router(auth.router)
app.include_router(users_router)
app.include_router(messages.router)
+ app.include_router(workspaces_router)
origins = [
"http://localhost",
diff --git a/backend/app/auth/dependencies.py b/backend/app/auth/dependencies.py
index 534c87f..d34a593 100644
--- a/backend/app/auth/dependencies.py
+++ b/backend/app/auth/dependencies.py
@@ -12,7 +12,7 @@
from jwt.exceptions import InvalidTokenError
from sqlalchemy.ext.asyncio import AsyncSession
-from ..config import CHECK_API_LIMIT, DEFAULT_API_QUOTA, DEFAULT_EXPERIMENTS_QUOTA, ENV
+from ..config import CHECK_API_LIMIT, DEFAULT_API_QUOTA, DEFAULT_EXPERIMENTS_QUOTA
from ..database import get_async_session
from ..users.models import (
UserDB,
@@ -185,7 +185,7 @@ async def get_verified_user(
return user_db
-def create_access_token(username: str) -> str:
+def create_access_token(username: str, workspace_name: str = None) -> str:
"""
Create an access token for the user
"""
@@ -198,6 +198,9 @@ def create_access_token(username: str) -> str:
payload["iat"] = datetime.now(timezone.utc)
payload["sub"] = username
payload["type"] = "access_token"
+
+ if workspace_name:
+ payload["workspace_name"] = workspace_name
return jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
diff --git a/backend/app/auth/routers.py b/backend/app/auth/routers.py
index eec0c4a..c989ae8 100644
--- a/backend/app/auth/routers.py
+++ b/backend/app/auth/routers.py
@@ -6,6 +6,8 @@
from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import AsyncSession
+from ..config import DEFAULT_API_QUOTA, DEFAULT_EXPERIMENTS_QUOTA
+
from ..database import get_async_session, get_redis
from ..email import EmailService
from ..users.models import (
@@ -19,6 +21,7 @@
MessageResponse,
PasswordResetConfirm,
PasswordResetRequest,
+ UserCreate,
)
from ..utils import setup_logger
from .config import NEXT_PUBLIC_GOOGLE_LOGIN_CLIENT_ID
@@ -96,8 +99,17 @@ async def login_google(
except ValueError as e:
raise HTTPException(status_code=401, detail="Invalid token") from e
+ # Import here to avoid circular imports
+ from ..workspaces.models import (
+ create_user_workspace_role,
+ get_user_default_workspace,
+ UserRoles
+ )
+ from ..workspaces.utils import create_workspace
+
+ user_email = idinfo["email"]
user = await authenticate_or_create_google_user(
- request=request, google_email=idinfo["email"], asession=asession
+ request=request, google_email=user_email, asession=asession
)
if not user:
raise HTTPException(
@@ -105,8 +117,39 @@ async def login_google(
detail="Unable to create new user",
)
+ user_db = await get_user_by_username(username=user_email, asession=asession)
+
+ # Create default workspace if user is new (has no workspaces)
+ try:
+ default_workspace = await get_user_default_workspace(asession=asession, user_db=user_db)
+ default_workspace_name = default_workspace.workspace_name
+ except Exception:
+ # User doesn't have a default workspace, create one
+ default_workspace_name = f"{user_email}'s Workspace"
+
+ # Create default workspace
+ workspace_db, _ = await create_workspace(
+ api_daily_quota=DEFAULT_API_QUOTA,
+ asession=asession,
+ content_quota=DEFAULT_EXPERIMENTS_QUOTA,
+ user=UserCreate(
+ role=UserRoles.ADMIN,
+ username=user_email,
+ workspace_name=default_workspace_name,
+ ),
+ is_default=True
+ )
+
+ await create_user_workspace_role(
+ asession=asession,
+ is_default_workspace=True,
+ user_db=user_db,
+ user_role=UserRoles.ADMIN,
+ workspace_db=workspace_db,
+ )
+
return AuthenticationDetails(
- access_token=create_access_token(user.username),
+ access_token=create_access_token(user.username, default_workspace_name),
api_key_first_characters=user.api_key_first_characters,
token_type="bearer",
access_level=user.access_level,
diff --git a/backend/app/auth/utils.py b/backend/app/auth/utils.py
index cf8c15a..05b8f69 100644
--- a/backend/app/auth/utils.py
+++ b/backend/app/auth/utils.py
@@ -6,20 +6,12 @@
from redis.asyncio import Redis
from ..utils import setup_logger
-<<<<<<< HEAD
-from .config import JWT_ALGORITHM, JWT_SECRET
-
-# Token settings
-VERIFICATION_TOKEN_EXPIRE_MINUTES = 1440 # 24 hours
-PASSWORD_RESET_TOKEN_EXPIRE_MINUTES = 30 # 30 minutes
-=======
from .config import (
JWT_ALGORITHM,
JWT_SECRET,
PASSWORD_RESET_TOKEN_EXPIRE_MINUTES,
VERIFICATION_TOKEN_EXPIRE_MINUTES,
)
->>>>>>> origin
logger = setup_logger()
diff --git a/backend/app/contextual_mab/models.py b/backend/app/contextual_mab/models.py
index ca27a8b..da4d699 100644
--- a/backend/app/contextual_mab/models.py
+++ b/backend/app/contextual_mab/models.py
@@ -59,6 +59,7 @@ def to_dict(self) -> dict:
return {
"experiment_id": self.experiment_id,
"user_id": self.user_id,
+ "workspace_id": self.workspace_id,
"name": self.name,
"description": self.description,
"created_datetime_utc": self.created_datetime_utc,
@@ -191,6 +192,7 @@ def to_dict(self) -> dict:
async def save_contextual_mab_to_db(
experiment: ContextualBandit,
user_id: int,
+ workspace_id: int,
asession: AsyncSession,
) -> ContextualBanditDB:
"""
@@ -225,6 +227,7 @@ async def save_contextual_mab_to_db(
name=experiment.name,
description=experiment.description,
user_id=user_id,
+ workspace_id=workspace_id,
is_active=experiment.is_active,
created_datetime_utc=datetime.now(timezone.utc),
n_trials=0,
@@ -243,15 +246,17 @@ async def save_contextual_mab_to_db(
async def get_all_contextual_mabs(
user_id: int,
+ workspace_id: int,
asession: AsyncSession,
) -> Sequence[ContextualBanditDB]:
"""
- Get all the contextual experiments from the database.
+ Get all the contextual experiments from the database for a specific workspace.
"""
statement = (
select(ContextualBanditDB)
.where(
ContextualBanditDB.user_id == user_id,
+ ContextualBanditDB.workspace_id == workspace_id,
)
.order_by(ContextualBanditDB.experiment_id)
)
@@ -260,7 +265,10 @@ async def get_all_contextual_mabs(
async def get_contextual_mab_by_id(
- experiment_id: int, user_id: int, asession: AsyncSession
+ experiment_id: int,
+ user_id: int,
+ workspace_id: int,
+ asession: AsyncSession
) -> ContextualBanditDB | None:
"""
Get the contextual experiment by id.
@@ -268,6 +276,7 @@ async def get_contextual_mab_by_id(
result = await asession.execute(
select(ContextualBanditDB)
.where(ContextualBanditDB.user_id == user_id)
+ .where(ContextualBanditDB.workspace_id == workspace_id)
.where(ContextualBanditDB.experiment_id == experiment_id)
)
@@ -275,7 +284,10 @@ async def get_contextual_mab_by_id(
async def delete_contextual_mab_by_id(
- experiment_id: int, user_id: int, asession: AsyncSession
+ experiment_id: int,
+ user_id: int,
+ workspace_id: int,
+ asession: AsyncSession
) -> None:
"""
Delete the contextual experiment by id.
@@ -289,7 +301,6 @@ async def delete_contextual_mab_by_id(
await asession.execute(
delete(ContextualObservationDB).where(
and_(
- ContextualObservationDB.user_id == ObservationsBaseDB.user_id,
ContextualObservationDB.user_id == user_id,
ContextualObservationDB.experiment_id == experiment_id,
)
@@ -317,6 +328,7 @@ async def delete_contextual_mab_by_id(
and_(
ContextualBanditDB.experiment_id == ExperimentBaseDB.experiment_id,
ContextualBanditDB.user_id == user_id,
+ ContextualBanditDB.workspace_id == workspace_id,
ContextualBanditDB.experiment_id == experiment_id,
)
)
diff --git a/backend/app/contextual_mab/routers.py b/backend/app/contextual_mab/routers.py
index 95a6524..c56aa32 100644
--- a/backend/app/contextual_mab/routers.py
+++ b/backend/app/contextual_mab/routers.py
@@ -4,6 +4,9 @@
from fastapi.exceptions import HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
+from ..workspaces.models import get_user_default_workspace, get_user_role_in_workspace
+from ..workspaces.schemas import UserRoles
+
from ..auth.dependencies import authenticate_key, get_verified_user
from ..database import get_async_session
from ..models import get_notifications_from_db, save_notifications_to_db
@@ -41,7 +44,24 @@ async def create_contextual_mabs(
"""
Create a new contextual experiment with different priors for each context.
"""
- cmab = await save_contextual_mab_to_db(experiment, user_db.user_id, asession)
+ workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db)
+
+ user_role = await get_user_role_in_workspace(
+ asession=asession, user_db=user_db, workspace_db=workspace_db
+ )
+
+ if user_role != UserRoles.ADMIN:
+ raise HTTPException(
+ status_code=403,
+ detail="Only workspace administrators can create experiments.",
+ )
+
+ cmab = await save_contextual_mab_to_db(
+ experiment,
+ user_db.user_id,
+ workspace_db.workspace_id,
+ asession
+ )
notifications = await save_notifications_to_db(
experiment_id=cmab.experiment_id,
user_id=user_db.user_id,
@@ -61,7 +81,13 @@ async def get_contextual_mabs(
"""
Get details of all experiments.
"""
- experiments = await get_all_contextual_mabs(user_db.user_id, asession)
+ workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db)
+
+ experiments = await get_all_contextual_mabs(
+ user_db.user_id,
+ workspace_db.workspace_id,
+ asession
+ )
all_experiments = []
for exp in experiments:
exp_dict = exp.to_dict()
@@ -95,8 +121,13 @@ async def get_contextual_mab(
"""
Get details of experiment with the provided `experiment_id`.
"""
+ workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db)
+
experiment = await get_contextual_mab_by_id(
- experiment_id, user_db.user_id, asession
+ experiment_id,
+ user_db.user_id,
+ workspace_db.workspace_id,
+ asession
)
if experiment is None:
raise HTTPException(
@@ -123,14 +154,34 @@ async def delete_contextual_mab(
Delete the experiment with the provided `experiment_id`.
"""
try:
+ workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db)
+
+ user_role = await get_user_role_in_workspace(
+ asession=asession, user_db=user_db, workspace_db=workspace_db
+ )
+
+ if user_role != UserRoles.ADMIN:
+ raise HTTPException(
+ status_code=403,
+ detail="Only workspace administrators can delete experiments.",
+ )
+
experiment = await get_contextual_mab_by_id(
- experiment_id, user_db.user_id, asession
+ experiment_id,
+ user_db.user_id,
+ workspace_db.workspace_id,
+ asession
)
if experiment is None:
raise HTTPException(
status_code=404, detail=f"Experiment with id {experiment_id} not found"
)
- await delete_contextual_mab_by_id(experiment_id, user_db.user_id, asession)
+ await delete_contextual_mab_by_id(
+ experiment_id,
+ user_db.user_id,
+ workspace_db.workspace_id,
+ asession
+ )
return {"detail": f"Experiment {experiment_id} deleted successfully."}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error: {e}") from e
@@ -146,8 +197,13 @@ async def draw_arm(
"""
Get which arm to pull next for provided experiment.
"""
+ workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db)
+
experiment = await get_contextual_mab_by_id(
- experiment_id, user_db.user_id, asession
+ experiment_id,
+ user_db.user_id,
+ workspace_db.workspace_id,
+ asession
)
if experiment is None:
@@ -190,9 +246,14 @@ async def update_arm(
Update the arm with the provided `arm_id` for the given
`experiment_id` based on the `outcome`.
"""
+ workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db)
+
# Get the experiment and do checks
experiment = await get_contextual_mab_by_id(
- experiment_id, user_db.user_id, asession
+ experiment_id,
+ user_db.user_id,
+ workspace_db.workspace_id,
+ asession
)
if experiment is None:
raise HTTPException(
@@ -284,8 +345,13 @@ async def get_outcomes(
"""
Get the outcomes for the experiment.
"""
+ workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db)
+
experiment = await get_contextual_mab_by_id(
- experiment_id, user_db.user_id, asession
+ experiment_id,
+ user_db.user_id,
+ workspace_db.workspace_id,
+ asession
)
if not experiment:
raise HTTPException(
diff --git a/backend/app/contextual_mab/schemas.py b/backend/app/contextual_mab/schemas.py
index 79f1ee6..e09a8ba 100644
--- a/backend/app/contextual_mab/schemas.py
+++ b/backend/app/contextual_mab/schemas.py
@@ -171,6 +171,7 @@ class ContextualBanditResponse(ContextualBanditBase):
"""
experiment_id: int
+ workspace_id: int
arms: list[ContextualArmResponse]
contexts: list[ContextResponse]
notifications: list[NotificationsResponse]
diff --git a/backend/app/email.py b/backend/app/email.py
index 5776735..9770090 100644
--- a/backend/app/email.py
+++ b/backend/app/email.py
@@ -124,6 +124,65 @@ async def send_password_reset_email(
return await self._send_email(email, subject, html_body, text_body)
+ async def send_workspace_invitation_email(
+ self, email: str, username: str, inviter_email: str, workspace_name: str, user_exists: bool
+ ) -> Dict[str, Any]:
+ """
+ Send workspace invitation email
+ """
+ if user_exists:
+ subject = f"You've been invited to join a workspace: {workspace_name}"
+ html_body = f"""
+
+
+
+ Workspace Invitation
+ Hello {username},
+ You have been invited by {inviter_email} to join the workspace "{workspace_name}".
+ You have been added to this workspace. Log in to access it.
+ Login to Your Account
+
+
+ """
+ text_body = f"""
+ Workspace Invitation
+
+ Hello {username},
+
+ You have been invited by {inviter_email} to join the workspace "{workspace_name}".
+
+ You have been added to this workspace. Log in to access it.
+
+ {FRONTEND_URL}/login
+ """
+ else:
+ subject = f"Invitation to Create an Account and Join a Workspace: {workspace_name}"
+ html_body = f"""
+
+
+
+ Workspace Invitation
+ Hello,
+ You have been invited by {inviter_email} to join the workspace "{workspace_name}".
+ You need to create an account to join this workspace.
+ Create Your Account
+
+
+ """
+ text_body = f"""
+ Workspace Invitation
+
+ Hello,
+
+ You have been invited by {inviter_email} to join the workspace "{workspace_name}".
+
+ You need to create an account to join this workspace.
+
+ {FRONTEND_URL}/signup
+ """
+
+ return await self._send_email(email, subject, html_body, text_body)
+
async def _send_email(
self, recipient: str, subject: str, html_body: str, text_body: str
) -> Dict[str, Any]:
diff --git a/backend/app/mab/models.py b/backend/app/mab/models.py
index 799e293..377498d 100644
--- a/backend/app/mab/models.py
+++ b/backend/app/mab/models.py
@@ -50,6 +50,7 @@ def to_dict(self) -> dict:
return {
"experiment_id": self.experiment_id,
"user_id": self.user_id,
+ "workspace_id": self.workspace_id,
"name": self.name,
"description": self.description,
"created_datetime_utc": self.created_datetime_utc,
@@ -143,6 +144,7 @@ def to_dict(self) -> dict:
async def save_mab_to_db(
experiment: MultiArmedBandit,
user_id: int,
+ workspace_id: int,
asession: AsyncSession,
) -> MultiArmedBanditDB:
"""
@@ -159,6 +161,7 @@ async def save_mab_to_db(
name=experiment.name,
description=experiment.description,
user_id=user_id,
+ workspace_id=workspace_id,
is_active=experiment.is_active,
created_datetime_utc=datetime.now(timezone.utc),
n_trials=0,
@@ -176,15 +179,17 @@ async def save_mab_to_db(
async def get_all_mabs(
user_id: int,
+ workspace_id: int,
asession: AsyncSession,
) -> Sequence[MultiArmedBanditDB]:
"""
- Get all the experiments from the database.
+ Get all the experiments from the database for a specific workspace.
"""
statement = (
select(MultiArmedBanditDB)
.where(
MultiArmedBanditDB.user_id == user_id,
+ MultiArmedBanditDB.workspace_id == workspace_id,
)
.order_by(MultiArmedBanditDB.experiment_id)
)
@@ -193,7 +198,10 @@ async def get_all_mabs(
async def get_mab_by_id(
- experiment_id: int, user_id: int, asession: AsyncSession
+ experiment_id: int,
+ user_id: int,
+ workspace_id: int,
+ asession: AsyncSession
) -> MultiArmedBanditDB | None:
"""
Get the experiment by id.
@@ -201,6 +209,7 @@ async def get_mab_by_id(
result = await asession.execute(
select(MultiArmedBanditDB)
.where(MultiArmedBanditDB.user_id == user_id)
+ .where(MultiArmedBanditDB.workspace_id == workspace_id)
.where(MultiArmedBanditDB.experiment_id == experiment_id)
)
@@ -208,7 +217,10 @@ async def get_mab_by_id(
async def delete_mab_by_id(
- experiment_id: int, user_id: int, asession: AsyncSession
+ experiment_id: int,
+ user_id: int,
+ workspace_id: int,
+ asession: AsyncSession
) -> None:
"""
Delete the experiment by id.
@@ -243,6 +255,7 @@ async def delete_mab_by_id(
MultiArmedBanditDB.experiment_id == experiment_id,
MultiArmedBanditDB.experiment_id == ExperimentBaseDB.experiment_id,
MultiArmedBanditDB.user_id == user_id,
+ MultiArmedBanditDB.workspace_id == workspace_id,
)
)
)
diff --git a/backend/app/mab/routers.py b/backend/app/mab/routers.py
index 7fd838f..fd5af38 100644
--- a/backend/app/mab/routers.py
+++ b/backend/app/mab/routers.py
@@ -5,6 +5,9 @@
from fastapi.exceptions import HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
+from ..workspaces.models import get_user_default_workspace, get_user_role_in_workspace
+from ..workspaces.schemas import UserRoles
+
from ..auth.dependencies import authenticate_key, get_verified_user
from ..database import get_async_session
from ..models import get_notifications_from_db, save_notifications_to_db
@@ -38,9 +41,27 @@ async def create_mab(
asession: AsyncSession = Depends(get_async_session),
) -> MultiArmedBanditResponse:
"""
- Create a new experiment.
+ Create a new experiment in the user's current workspace.
"""
- mab = await save_mab_to_db(experiment, user_db.user_id, asession)
+ workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db)
+
+ user_role = await get_user_role_in_workspace(
+ asession=asession, user_db=user_db, workspace_db=workspace_db
+ )
+
+ if user_role != UserRoles.ADMIN:
+ raise HTTPException(
+ status_code=403,
+ detail="Only workspace administrators can create experiments.",
+ )
+
+ mab = await save_mab_to_db(
+ experiment,
+ user_db.user_id,
+ workspace_db.workspace_id,
+ asession
+ )
+
notifications = await save_notifications_to_db(
experiment_id=mab.experiment_id,
user_id=user_db.user_id,
@@ -60,9 +81,15 @@ async def get_mabs(
asession: AsyncSession = Depends(get_async_session),
) -> list[MultiArmedBanditResponse]:
"""
- Get details of all experiments.
+ Get details of all experiments in the user's current workspace.
"""
- experiments = await get_all_mabs(user_db.user_id, asession)
+ workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db)
+
+ experiments = await get_all_mabs(
+ user_db.user_id,
+ workspace_db.workspace_id,
+ asession
+ )
all_experiments = []
for exp in experiments:
@@ -95,7 +122,14 @@ async def get_mab(
"""
Get details of experiment with the provided `experiment_id`.
"""
- experiment = await get_mab_by_id(experiment_id, user_db.user_id, asession)
+ workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db)
+
+ experiment = await get_mab_by_id(
+ experiment_id,
+ user_db.user_id,
+ workspace_db.workspace_id,
+ asession
+ )
if experiment is None:
raise HTTPException(
@@ -123,12 +157,34 @@ async def delete_mab(
Delete the experiment with the provided `experiment_id`.
"""
try:
- experiment = await get_mab_by_id(experiment_id, user_db.user_id, asession)
+ workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db)
+
+ user_role = await get_user_role_in_workspace(
+ asession=asession, user_db=user_db, workspace_db=workspace_db
+ )
+
+ if user_role != UserRoles.ADMIN:
+ raise HTTPException(
+ status_code=403,
+ detail="Only workspace administrators can delete experiments.",
+ )
+
+ experiment = await get_mab_by_id(
+ experiment_id,
+ user_db.user_id,
+ workspace_db.workspace_id,
+ asession
+ )
if experiment is None:
raise HTTPException(
status_code=404, detail=f"Experiment with id {experiment_id} not found"
)
- await delete_mab_by_id(experiment_id, user_db.user_id, asession)
+ await delete_mab_by_id(
+ experiment_id,
+ user_db.user_id,
+ workspace_db.workspace_id,
+ asession
+ )
return {"message": f"Experiment with id {experiment_id} deleted successfully."}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error: {e}") from e
@@ -143,7 +199,14 @@ async def draw_arm(
"""
Get which arm to pull next for provided experiment.
"""
- experiment = await get_mab_by_id(experiment_id, user_db.user_id, asession)
+ workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db)
+
+ experiment = await get_mab_by_id(
+ experiment_id,
+ user_db.user_id,
+ workspace_db.workspace_id,
+ asession
+ )
if experiment is None:
raise HTTPException(
status_code=404, detail=f"Experiment with id {experiment_id} not found"
@@ -165,8 +228,14 @@ async def update_arm(
Update the arm with the provided `arm_id` for the given
`experiment_id` based on the `outcome`.
"""
- # Get and validate experiment
- experiment = await get_mab_by_id(experiment_id, user_db.user_id, asession)
+ workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db)
+
+ experiment = await get_mab_by_id(
+ experiment_id,
+ user_db.user_id,
+ workspace_db.workspace_id,
+ asession
+ )
if experiment is None:
raise HTTPException(
status_code=404, detail=f"Experiment with id {experiment_id} not found"
@@ -233,7 +302,14 @@ async def get_outcomes(
"""
Get the outcomes for the experiment.
"""
- experiment = await get_mab_by_id(experiment_id, user_db.user_id, asession)
+ workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db)
+
+ experiment = await get_mab_by_id(
+ experiment_id,
+ user_db.user_id,
+ workspace_db.workspace_id,
+ asession
+ )
if not experiment:
raise HTTPException(
status_code=404, detail=f"Experiment with id {experiment_id} not found"
diff --git a/backend/app/mab/schemas.py b/backend/app/mab/schemas.py
index 18cd4c7..6270319 100644
--- a/backend/app/mab/schemas.py
+++ b/backend/app/mab/schemas.py
@@ -172,6 +172,7 @@ class MultiArmedBanditResponse(MultiArmedBanditBase):
"""
experiment_id: int
+ workspace_id: int
arms: list[ArmResponse]
notifications: list[NotificationsResponse]
created_datetime_utc: datetime
diff --git a/backend/app/models.py b/backend/app/models.py
index caa1ac0..5314ebb 100644
--- a/backend/app/models.py
+++ b/backend/app/models.py
@@ -3,7 +3,7 @@
from sqlalchemy import Boolean, DateTime, Enum, ForeignKey, Integer, String, select
from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
+from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from .schemas import EventType, Notifications
@@ -29,6 +29,9 @@ class ExperimentBaseDB(Base):
user_id: Mapped[int] = mapped_column(
Integer, ForeignKey("users.user_id"), nullable=False
)
+ workspace_id: Mapped[int] = mapped_column(
+ Integer, ForeignKey("workspace.workspace_id"), nullable=False
+ )
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
exp_type: Mapped[str] = mapped_column(String(length=50), nullable=False)
prior_type: Mapped[str] = mapped_column(String(length=50), nullable=False)
@@ -41,6 +44,8 @@ class ExperimentBaseDB(Base):
last_trial_datetime_utc: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=True
)
+ workspace: Mapped["WorkspaceDB"] = relationship("WorkspaceDB", back_populates="experiments")
+
__mapper_args__ = {
"polymorphic_identity": "experiment",
"polymorphic_on": "exp_type",
diff --git a/backend/app/users/models.py b/backend/app/users/models.py
index 4de00f8..6f8bfc2 100644
--- a/backend/app/users/models.py
+++ b/backend/app/users/models.py
@@ -56,8 +56,7 @@ class UserDB(Base):
)
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
is_verified: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
-
- # Relationships for workspaces
+
user_workspaces: Mapped[list["UserWorkspaceDB"]] = relationship(
"UserWorkspaceDB",
back_populates="user",
diff --git a/backend/app/users/routers.py b/backend/app/users/routers.py
index 5980200..90844bb 100644
--- a/backend/app/users/routers.py
+++ b/backend/app/users/routers.py
@@ -6,6 +6,8 @@
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
+from ..config import DEFAULT_API_QUOTA, DEFAULT_EXPERIMENTS_QUOTA
+
from ..auth.dependencies import get_current_user, get_verified_user
from ..auth.utils import generate_verification_token
from ..database import get_async_session, get_redis
@@ -44,9 +46,9 @@ async def create_user(
"""
try:
# Import workspace functionality to avoid circular imports
- from ..workspaces.models import create_user_workspace_role, UserRoles
+ from ..workspaces.models import UserRoles, create_user_workspace_role
from ..workspaces.utils import create_workspace
-
+
# Create the user
new_api_key = generate_key()
user_new = await save_user_to_db(
@@ -60,20 +62,20 @@ async def create_user(
# Create default workspace for the user
default_workspace_name = f"{user_new.username}'s Workspace"
workspace_api_key = generate_key()
-
+
workspace_db, _ = await create_workspace(
api_daily_quota=DEFAULT_API_QUOTA,
asession=asession,
- content_quota=DEFAULT_CONTENT_QUOTA,
+ content_quota=DEFAULT_EXPERIMENTS_QUOTA,
user=UserCreate(
role=UserRoles.ADMIN,
username=user_new.username,
workspace_name=default_workspace_name,
),
is_default=True,
- api_key=workspace_api_key
+ api_key=workspace_api_key,
)
-
+
# Add user to workspace as admin
await create_user_workspace_role(
asession=asession,
diff --git a/backend/app/users/schemas.py b/backend/app/users/schemas.py
index 5bb3294..c2fb6b5 100644
--- a/backend/app/users/schemas.py
+++ b/backend/app/users/schemas.py
@@ -12,6 +12,9 @@ class UserCreate(BaseModel):
username: str
experiments_quota: Optional[int] = None
api_daily_quota: Optional[int] = None
+ workspace_name: Optional[str] = None
+ role: Optional[str] = None
+ is_default_workspace: Optional[bool] = False
model_config = ConfigDict(from_attributes=True)
diff --git a/backend/app/workspaces/__init__.py b/backend/app/workspaces/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/backend/app/workspaces/models.py b/backend/app/workspaces/models.py
new file mode 100644
index 0000000..f7b37a2
--- /dev/null
+++ b/backend/app/workspaces/models.py
@@ -0,0 +1,303 @@
+from datetime import datetime, timezone
+from typing import Sequence, TYPE_CHECKING
+
+import sqlalchemy.sql.functions as func
+from sqlalchemy import (
+ Boolean,
+ DateTime,
+ Enum,
+ ForeignKey,
+ Integer,
+ Row,
+ String,
+ case,
+ exists,
+ select,
+ text,
+ update,
+)
+from sqlalchemy.exc import NoResultFound
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.orm import Mapped, mapped_column, relationship
+
+from ..users.schemas import UserCreate
+
+from ..models import Base, ExperimentBaseDB
+from ..utils import get_key_hash
+from .schemas import UserRoles, UserCreateWithCode
+
+# Use TYPE_CHECKING for type hints without runtime imports
+if TYPE_CHECKING:
+ from ..users.models import UserDB
+
+
+class UserWorkspaceRoleAlreadyExistsError(Exception):
+ """Exception raised when a user workspace role already exists in the database."""
+
+
+class UserNotFoundInWorkspaceError(Exception):
+ """Exception raised when a user is not found in a workspace in the database."""
+
+
+class WorkspaceDB(Base):
+ """ORM for managing workspaces.
+
+ A workspace is an isolated virtual environment that contains contents that can be
+ accessed and modified by users assigned to that workspace. Workspaces must be
+ unique but can contain duplicated content. Users can be assigned to one more
+ workspaces, with different roles. In other words, there is a MANY-to-MANY
+ relationship between users and workspaces.
+ """
+
+ __tablename__ = "workspace"
+
+ api_daily_quota: Mapped[int | None] = mapped_column(Integer, nullable=True)
+ api_key_first_characters: Mapped[str] = mapped_column(String(5), nullable=True)
+ api_key_updated_datetime_utc: Mapped[datetime] = mapped_column(
+ DateTime(timezone=True), nullable=True
+ )
+ content_quota: Mapped[int | None] = mapped_column(Integer, nullable=True)
+ created_datetime_utc: Mapped[datetime] = mapped_column(
+ DateTime(timezone=True), nullable=False
+ )
+ hashed_api_key: Mapped[str] = mapped_column(String(96), nullable=True, unique=True)
+ updated_datetime_utc: Mapped[datetime] = mapped_column(
+ DateTime(timezone=True), nullable=False
+ )
+ user_workspaces: Mapped[list["UserWorkspaceDB"]] = relationship(
+ "UserWorkspaceDB",
+ back_populates="workspace",
+ cascade="all, delete-orphan",
+ passive_deletes=True,
+ )
+ users: Mapped[list["UserDB"]] = relationship(
+ "UserDB", back_populates="workspaces", secondary="user_workspace", viewonly=True
+ )
+ workspace_id: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False)
+ workspace_name: Mapped[str] = mapped_column(String, nullable=False, unique=True)
+ is_default: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
+ experiments: Mapped[list["ExperimentBaseDB"]] = relationship(
+ "ExperimentBaseDB", back_populates="workspace", cascade="all, delete-orphan"
+ )
+
+ def __repr__(self) -> str:
+ """Define the string representation for the `WorkspaceDB` class."""
+ return f""
+
+
+class UserWorkspaceDB(Base):
+ """ORM for managing user in workspaces."""
+
+ __tablename__ = "user_workspace"
+
+ created_datetime_utc: Mapped[datetime] = mapped_column(
+ DateTime(timezone=True), nullable=False
+ )
+ default_workspace: Mapped[bool] = mapped_column(
+ Boolean,
+ nullable=False,
+ server_default=text("false"), # Ensures existing rows default to false
+ )
+ updated_datetime_utc: Mapped[datetime] = mapped_column(
+ DateTime(timezone=True), nullable=False
+ )
+ user: Mapped["UserDB"] = relationship("UserDB", back_populates="user_workspaces")
+ user_id: Mapped[int] = mapped_column(
+ Integer, ForeignKey("users.user_id", ondelete="CASCADE"), primary_key=True
+ )
+ user_role: Mapped[UserRoles] = mapped_column(
+ Enum(UserRoles, native_enum=False), nullable=False
+ )
+ workspace: Mapped["WorkspaceDB"] = relationship(
+ "WorkspaceDB", back_populates="user_workspaces"
+ )
+ workspace_id: Mapped[int] = mapped_column(
+ Integer,
+ ForeignKey("workspace.workspace_id", ondelete="CASCADE"),
+ primary_key=True,
+ )
+
+ def __repr__(self) -> str:
+ """Define the string representation for the `UserWorkspaceDB` class."""
+ return f"."
+
+
+async def check_if_user_has_default_workspace(
+ *, asession: AsyncSession, user_db: "UserDB"
+) -> bool | None:
+ """Check if a user has an assigned default workspace."""
+ stmt = select(
+ exists().where(
+ UserWorkspaceDB.user_id == user_db.user_id,
+ UserWorkspaceDB.default_workspace.is_(True),
+ )
+ )
+ result = await asession.execute(stmt)
+ return result.scalar()
+
+
+async def get_user_default_workspace(
+ *, asession: AsyncSession, user_db: "UserDB"
+) -> WorkspaceDB:
+ """Retrieve the default workspace for a given user."""
+ stmt = (
+ select(WorkspaceDB)
+ .join(UserWorkspaceDB, UserWorkspaceDB.workspace_id == WorkspaceDB.workspace_id)
+ .where(
+ UserWorkspaceDB.user_id == user_db.user_id,
+ UserWorkspaceDB.default_workspace.is_(True),
+ )
+ .limit(1)
+ )
+
+ result = await asession.execute(stmt)
+ default_workspace_db = result.scalar_one()
+ return default_workspace_db
+
+
+async def get_user_workspaces(
+ *, asession: AsyncSession, user_db: "UserDB"
+) -> Sequence[WorkspaceDB]:
+ """Retrieve all workspaces a user belongs to."""
+ stmt = (
+ select(WorkspaceDB)
+ .join(UserWorkspaceDB, UserWorkspaceDB.workspace_id == WorkspaceDB.workspace_id)
+ .where(UserWorkspaceDB.user_id == user_db.user_id)
+ )
+ result = await asession.execute(stmt)
+ return result.scalars().all()
+
+
+async def get_user_role_in_workspace(
+ *, asession: AsyncSession, user_db: "UserDB", workspace_db: WorkspaceDB
+) -> UserRoles | None:
+ """Retrieve the role of a user in a workspace."""
+ stmt = select(UserWorkspaceDB.user_role).where(
+ UserWorkspaceDB.user_id == user_db.user_id,
+ UserWorkspaceDB.workspace_id == workspace_db.workspace_id,
+ )
+ result = await asession.execute(stmt)
+ user_role = result.scalar_one_or_none()
+ return user_role
+
+
+async def update_user_default_workspace(
+ *, asession: AsyncSession, user_db: "UserDB", workspace_db: WorkspaceDB
+) -> None:
+ """Update the default workspace for the user to the specified workspace."""
+ stmt = (
+ update(UserWorkspaceDB)
+ .where(UserWorkspaceDB.user_id == user_db.user_id)
+ .values(
+ default_workspace=case(
+ (UserWorkspaceDB.workspace_id == workspace_db.workspace_id, True),
+ else_=False,
+ ),
+ updated_datetime_utc=datetime.now(timezone.utc)
+ )
+ )
+
+ await asession.execute(stmt)
+ await asession.commit()
+
+
+async def create_user_workspace_role(
+ *,
+ asession: AsyncSession,
+ is_default_workspace: bool = False,
+ user_db: "UserDB",
+ user_role: UserRoles,
+ workspace_db: WorkspaceDB,
+) -> UserWorkspaceDB:
+ """Create a user in a workspace with the specified role."""
+ existing_user_role = await get_user_role_in_workspace(
+ asession=asession, user_db=user_db, workspace_db=workspace_db
+ )
+
+ if existing_user_role is not None:
+ raise UserWorkspaceRoleAlreadyExistsError(
+ f"User '{user_db.username}' with role '{user_role}' in workspace "
+ f"{workspace_db.workspace_name} already exists."
+ )
+
+ user_workspace_role_db = UserWorkspaceDB(
+ created_datetime_utc=datetime.now(timezone.utc),
+ default_workspace=is_default_workspace,
+ updated_datetime_utc=datetime.now(timezone.utc),
+ user_id=user_db.user_id,
+ user_role=user_role,
+ workspace_id=workspace_db.workspace_id,
+ )
+
+ asession.add(user_workspace_role_db)
+ await asession.commit()
+ await asession.refresh(user_workspace_role_db)
+
+ return user_workspace_role_db
+
+
+async def get_workspaces_by_user_role(
+ *, asession: AsyncSession, user_db: "UserDB", user_role: UserRoles
+) -> Sequence[WorkspaceDB]:
+ """Retrieve all workspaces for the user with the specified role."""
+ stmt = (
+ select(WorkspaceDB)
+ .join(UserWorkspaceDB, WorkspaceDB.workspace_id == UserWorkspaceDB.workspace_id)
+ .where(UserWorkspaceDB.user_id == user_db.user_id)
+ .where(UserWorkspaceDB.user_role == user_role)
+ )
+ result = await asession.execute(stmt)
+ return result.scalars().all()
+
+
+async def user_has_admin_role_in_any_workspace(
+ *, asession: AsyncSession, user_db: "UserDB"
+) -> bool:
+ """Check if a user has the ADMIN role in any workspace."""
+ stmt = (
+ select(UserWorkspaceDB.user_id)
+ .where(
+ UserWorkspaceDB.user_id == user_db.user_id,
+ UserWorkspaceDB.user_role == UserRoles.ADMIN,
+ )
+ .limit(1)
+ )
+ result = await asession.execute(stmt)
+ return result.scalar_one_or_none() is not None
+
+
+async def add_existing_user_to_workspace(
+ *,
+ asession: AsyncSession,
+ user: UserCreate,
+ workspace_db: WorkspaceDB,
+) -> UserCreateWithCode:
+ """Add an existing user to a workspace."""
+ # Import here to avoid circular imports
+ from ..users.models import get_user_by_username
+
+ assert user.role is not None
+ user.is_default_workspace = user.is_default_workspace or False
+
+ user_db = await get_user_by_username(username=user.username, asession=asession)
+
+ if user.is_default_workspace:
+ await update_user_default_workspace(
+ asession=asession, user_db=user_db, workspace_db=workspace_db
+ )
+
+ _ = await create_user_workspace_role(
+ asession=asession,
+ is_default_workspace=user.is_default_workspace,
+ user_db=user_db,
+ user_role=user.role,
+ workspace_db=workspace_db,
+ )
+
+ return UserCreateWithCode(
+ is_default_workspace=user.is_default_workspace,
+ recovery_codes=[], # We don't use recovery codes in your implementation
+ role=user.role,
+ username=user_db.username,
+ workspace_name=workspace_db.workspace_name,
+ )
diff --git a/backend/app/workspaces/routers.py b/backend/app/workspaces/routers.py
new file mode 100644
index 0000000..d36fb1a
--- /dev/null
+++ b/backend/app/workspaces/routers.py
@@ -0,0 +1,479 @@
+from typing import Annotated, List
+
+from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, status
+from fastapi.exceptions import HTTPException
+from sqlalchemy.exc import SQLAlchemyError
+from sqlalchemy.ext.asyncio import AsyncSession
+from redis.asyncio import Redis
+
+from ..auth.dependencies import (
+ create_access_token,
+ get_current_user,
+)
+from ..auth.schemas import AuthenticationDetails
+from ..config import DEFAULT_API_QUOTA, DEFAULT_EXPERIMENTS_QUOTA
+from ..database import get_async_session, get_redis
+from ..email import EmailService
+from ..users.models import (
+ UserDB,
+ UserNotFoundError,
+ save_user_to_db,
+ get_user_by_username,
+)
+from ..users.schemas import UserCreate
+from ..utils import generate_key, setup_logger
+from .models import (
+ add_existing_user_to_workspace,
+ check_if_user_has_default_workspace,
+ get_user_default_workspace,
+ get_user_role_in_workspace,
+ get_user_workspaces,
+ get_workspaces_by_user_role,
+ update_user_default_workspace,
+ user_has_admin_role_in_any_workspace,
+)
+from .schemas import (
+ UserRoles,
+ WorkspaceCreate,
+ WorkspaceInvite,
+ WorkspaceInviteResponse,
+ WorkspaceKeyResponse,
+ WorkspaceRetrieve,
+ WorkspaceSwitch,
+ WorkspaceUpdate,
+)
+from .utils import (
+ WorkspaceNotFoundError,
+ create_workspace,
+ get_workspace_by_workspace_id,
+ get_workspace_by_workspace_name,
+ is_workspace_name_valid,
+ update_workspace_api_key,
+ update_workspace_name_and_quotas,
+)
+
+TAG_METADATA = {
+ "name": "Workspace",
+ "description": "_Requires user login._ Only administrator user has access to these "
+ "endpoints and only for the workspaces that they are assigned to.",
+}
+
+router = APIRouter(prefix="/workspace", tags=["Workspace"])
+logger = setup_logger()
+email_service = EmailService()
+
+
+@router.post("/", response_model=WorkspaceRetrieve)
+async def create_workspace_endpoint(
+ calling_user_db: Annotated[UserDB, Depends(get_current_user)],
+ workspace: WorkspaceCreate,
+ asession: AsyncSession = Depends(get_async_session),
+) -> WorkspaceRetrieve:
+ """Create a new workspace. Workspaces can only be created by authenticated users."""
+ if not await check_if_user_has_default_workspace(
+ asession=asession, user_db=calling_user_db
+ ):
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail="User must be assigned to a workspace first before creating new workspaces.",
+ )
+
+ # Check if workspace name is valid
+ if not await is_workspace_name_valid(
+ asession=asession, workspace_name=workspace.workspace_name
+ ):
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=f"Workspace with name '{workspace.workspace_name}' already exists.",
+ )
+
+ # Create new workspace
+ api_key = generate_key()
+ workspace_db, is_new_workspace = await create_workspace(
+ api_daily_quota=workspace.api_daily_quota or DEFAULT_API_QUOTA,
+ asession=asession,
+ content_quota=workspace.content_quota or DEFAULT_EXPERIMENTS_QUOTA,
+ user=UserCreate(
+ role=UserRoles.ADMIN,
+ username=calling_user_db.username,
+ workspace_name=workspace.workspace_name,
+ ),
+ api_key=api_key,
+ )
+
+ if is_new_workspace:
+ # Add the calling user as an admin to the new workspace
+ await add_existing_user_to_workspace(
+ asession=asession,
+ user=UserCreate(
+ is_default_workspace=False, # Don't make it default automatically
+ role=UserRoles.ADMIN,
+ username=calling_user_db.username,
+ workspace_name=workspace_db.workspace_name,
+ ),
+ workspace_db=workspace_db,
+ )
+
+ return WorkspaceRetrieve(
+ api_daily_quota=workspace_db.api_daily_quota,
+ api_key_first_characters=workspace_db.api_key_first_characters,
+ api_key_updated_datetime_utc=workspace_db.api_key_updated_datetime_utc,
+ content_quota=workspace_db.content_quota,
+ created_datetime_utc=workspace_db.created_datetime_utc,
+ updated_datetime_utc=workspace_db.updated_datetime_utc,
+ workspace_id=workspace_db.workspace_id,
+ workspace_name=workspace_db.workspace_name,
+ is_default=workspace_db.is_default,
+ )
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail="Workspace already exists.",
+ )
+
+
+@router.get("/", response_model=List[WorkspaceRetrieve])
+async def retrieve_all_workspaces(
+ calling_user_db: Annotated[UserDB, Depends(get_current_user)],
+ asession: AsyncSession = Depends(get_async_session),
+) -> List[WorkspaceRetrieve]:
+ """Return a list of all workspaces the user has access to."""
+ user_workspaces = await get_user_workspaces(asession=asession, user_db=calling_user_db)
+
+ return [
+ WorkspaceRetrieve(
+ api_daily_quota=workspace_db.api_daily_quota,
+ api_key_first_characters=workspace_db.api_key_first_characters,
+ api_key_updated_datetime_utc=workspace_db.api_key_updated_datetime_utc,
+ content_quota=workspace_db.content_quota,
+ created_datetime_utc=workspace_db.created_datetime_utc,
+ updated_datetime_utc=workspace_db.updated_datetime_utc,
+ workspace_id=workspace_db.workspace_id,
+ workspace_name=workspace_db.workspace_name,
+ is_default=workspace_db.is_default,
+ )
+ for workspace_db in user_workspaces
+ ]
+
+
+@router.get("/current", response_model=WorkspaceRetrieve)
+async def get_current_workspace(
+ calling_user_db: Annotated[UserDB, Depends(get_current_user)],
+ asession: AsyncSession = Depends(get_async_session),
+) -> WorkspaceRetrieve:
+ """Return the current default workspace for the user."""
+ try:
+ workspace_db = await get_user_default_workspace(
+ asession=asession, user_db=calling_user_db
+ )
+
+ return WorkspaceRetrieve(
+ api_daily_quota=workspace_db.api_daily_quota,
+ api_key_first_characters=workspace_db.api_key_first_characters,
+ api_key_updated_datetime_utc=workspace_db.api_key_updated_datetime_utc,
+ content_quota=workspace_db.content_quota,
+ created_datetime_utc=workspace_db.created_datetime_utc,
+ updated_datetime_utc=workspace_db.updated_datetime_utc,
+ workspace_id=workspace_db.workspace_id,
+ workspace_name=workspace_db.workspace_name,
+ is_default=workspace_db.is_default,
+ )
+ except Exception as e:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="No default workspace found for the user.",
+ ) from e
+
+
+@router.post("/switch", response_model=AuthenticationDetails)
+async def switch_workspace(
+ calling_user_db: Annotated[UserDB, Depends(get_current_user)],
+ workspace_switch: WorkspaceSwitch,
+ asession: AsyncSession = Depends(get_async_session),
+) -> AuthenticationDetails:
+ """Switch to a different workspace."""
+ # Find the workspace
+ try:
+ workspace_db = await get_workspace_by_workspace_name(
+ asession=asession, workspace_name=workspace_switch.workspace_name
+ )
+ except WorkspaceNotFoundError as e:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail=f"Workspace '{workspace_switch.workspace_name}' not found.",
+ ) from e
+
+ # Check if user belongs to this workspace
+ user_role = await get_user_role_in_workspace(
+ asession=asession, user_db=calling_user_db, workspace_db=workspace_db
+ )
+
+ if user_role is None:
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail=f"User does not have access to workspace '{workspace_switch.workspace_name}'.",
+ )
+
+ # Set this workspace as the default for the user
+ await update_user_default_workspace(
+ asession=asession, user_db=calling_user_db, workspace_db=workspace_db
+ )
+
+ # Create a new token with the updated workspace information
+ return AuthenticationDetails(
+ access_level="fullaccess",
+ access_token=create_access_token(calling_user_db.username),
+ token_type="bearer",
+ username=calling_user_db.username,
+ is_verified=calling_user_db.is_verified,
+ api_key_first_characters=calling_user_db.api_key_first_characters,
+ )
+
+
+@router.put("/rotate-key", response_model=WorkspaceKeyResponse)
+async def rotate_workspace_api_key(
+ calling_user_db: Annotated[UserDB, Depends(get_current_user)],
+ asession: AsyncSession = Depends(get_async_session),
+) -> WorkspaceKeyResponse:
+ """Generate a new API key for the current workspace."""
+ try:
+ # Get the user's default workspace
+ workspace_db = await get_user_default_workspace(
+ asession=asession, user_db=calling_user_db
+ )
+
+ # Verify user is an admin in this workspace
+ user_role = await get_user_role_in_workspace(
+ asession=asession, user_db=calling_user_db, workspace_db=workspace_db
+ )
+
+ if user_role != UserRoles.ADMIN:
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail="Only workspace administrators can rotate API keys.",
+ )
+
+ # Generate and update the API key
+ new_api_key = generate_key()
+ asession.add(workspace_db)
+ workspace_db = await update_workspace_api_key(
+ asession=asession, new_api_key=new_api_key, workspace_db=workspace_db
+ )
+
+ return WorkspaceKeyResponse(
+ new_api_key=new_api_key,
+ workspace_name=workspace_db.workspace_name,
+ )
+ except Exception as e:
+ logger.error(f"Error rotating workspace API key: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail="Error rotating workspace API key.",
+ ) from e
+
+
+@router.get("/{workspace_id}", response_model=WorkspaceRetrieve)
+async def retrieve_workspace_by_workspace_id(
+ calling_user_db: Annotated[UserDB, Depends(get_current_user)],
+ workspace_id: int,
+ asession: AsyncSession = Depends(get_async_session),
+) -> WorkspaceRetrieve:
+ """Retrieve a workspace by ID."""
+ try:
+ # Get the workspace
+ workspace_db = await get_workspace_by_workspace_id(
+ asession=asession, workspace_id=workspace_id
+ )
+
+ # Check if user has access to this workspace
+ user_role = await get_user_role_in_workspace(
+ asession=asession, user_db=calling_user_db, workspace_db=workspace_db
+ )
+
+ if user_role is None:
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail=f"User does not have access to workspace with ID {workspace_id}.",
+ )
+
+ return WorkspaceRetrieve(
+ api_daily_quota=workspace_db.api_daily_quota,
+ api_key_first_characters=workspace_db.api_key_first_characters,
+ api_key_updated_datetime_utc=workspace_db.api_key_updated_datetime_utc,
+ content_quota=workspace_db.content_quota,
+ created_datetime_utc=workspace_db.created_datetime_utc,
+ updated_datetime_utc=workspace_db.updated_datetime_utc,
+ workspace_id=workspace_db.workspace_id,
+ workspace_name=workspace_db.workspace_name,
+ is_default=workspace_db.is_default,
+ )
+ except WorkspaceNotFoundError as e:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail=f"Workspace with ID {workspace_id} not found.",
+ ) from e
+
+
+@router.put("/{workspace_id}", response_model=WorkspaceRetrieve)
+async def update_workspace_endpoint(
+ workspace_id: int,
+ workspace_update: WorkspaceUpdate,
+ calling_user_db: Annotated[UserDB, Depends(get_current_user)],
+ asession: AsyncSession = Depends(get_async_session),
+) -> WorkspaceRetrieve:
+ """Update workspace details (name, quotas)."""
+ try:
+ # Get the workspace
+ workspace_db = await get_workspace_by_workspace_id(
+ asession=asession, workspace_id=workspace_id
+ )
+
+ # Verify user is an admin in this workspace
+ user_role = await get_user_role_in_workspace(
+ asession=asession, user_db=calling_user_db, workspace_db=workspace_db
+ )
+
+ if user_role != UserRoles.ADMIN:
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail="Only workspace administrators can update workspace details.",
+ )
+
+ # Check if the new workspace name is valid
+ if workspace_update.workspace_name and workspace_update.workspace_name != workspace_db.workspace_name:
+ if not await is_workspace_name_valid(
+ asession=asession, workspace_name=workspace_update.workspace_name
+ ):
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=f"Workspace with name '{workspace_update.workspace_name}' already exists.",
+ )
+
+ # Update the workspace
+ asession.add(workspace_db)
+ updated_workspace = await update_workspace_name_and_quotas(
+ asession=asession, workspace=workspace_update, workspace_db=workspace_db
+ )
+
+ return WorkspaceRetrieve(
+ api_daily_quota=updated_workspace.api_daily_quota,
+ api_key_first_characters=updated_workspace.api_key_first_characters,
+ api_key_updated_datetime_utc=updated_workspace.api_key_updated_datetime_utc,
+ content_quota=updated_workspace.content_quota,
+ created_datetime_utc=updated_workspace.created_datetime_utc,
+ updated_datetime_utc=updated_workspace.updated_datetime_utc,
+ workspace_id=updated_workspace.workspace_id,
+ workspace_name=updated_workspace.workspace_name,
+ is_default=updated_workspace.is_default,
+ )
+ except WorkspaceNotFoundError as e:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail=f"Workspace with ID {workspace_id} not found.",
+ ) from e
+ except SQLAlchemyError as e:
+ logger.error(f"Database error when updating workspace: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail="Database error when updating workspace.",
+ ) from e
+
+
+@router.post("/invite", response_model=WorkspaceInviteResponse)
+async def invite_user_to_workspace(
+ calling_user_db: Annotated[UserDB, Depends(get_current_user)],
+ invite: WorkspaceInvite,
+ background_tasks: BackgroundTasks,
+ asession: AsyncSession = Depends(get_async_session),
+ redis: Redis = Depends(get_redis),
+) -> WorkspaceInviteResponse:
+ """Invite a user to join a workspace."""
+ try:
+ # Get the workspace
+ workspace_db = await get_workspace_by_workspace_name(
+ asession=asession, workspace_name=invite.workspace_name
+ )
+
+ # Check if it's a default workspace (users can't invite others to default workspaces)
+ if workspace_db.is_default:
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail="Users cannot be invited to default workspaces.",
+ )
+
+ # Verify caller is an admin in this workspace
+ user_role = await get_user_role_in_workspace(
+ asession=asession, user_db=calling_user_db, workspace_db=workspace_db
+ )
+
+ if user_role != UserRoles.ADMIN:
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail="Only workspace administrators can invite users.",
+ )
+
+ # Check if the invited user exists
+ user_exists = False
+ try:
+ invited_user = await get_user_by_username(
+ username=invite.email, asession=asession
+ )
+ user_exists = True
+
+ # Add existing user to workspace
+ await add_existing_user_to_workspace(
+ asession=asession,
+ user=UserCreate(
+ role=invite.role,
+ username=invite.email,
+ workspace_name=invite.workspace_name,
+ ),
+ workspace_db=workspace_db,
+ )
+
+ # Send invitation email to existing user
+ background_tasks.add_task(
+ email_service.send_workspace_invitation_email,
+ invite.email,
+ invite.email,
+ calling_user_db.username,
+ workspace_db.workspace_name,
+ True, # user exists
+ )
+
+ return WorkspaceInviteResponse(
+ message=f"User {invite.email} has been added to workspace '{workspace_db.workspace_name}'.",
+ email=invite.email,
+ workspace_name=workspace_db.workspace_name,
+ user_exists=True,
+ )
+
+ except UserNotFoundError:
+ # User doesn't exist, send invitation to create account
+ background_tasks.add_task(
+ email_service.send_workspace_invitation_email,
+ invite.email,
+ invite.email,
+ calling_user_db.username,
+ workspace_db.workspace_name,
+ False, # user doesn't exist
+ )
+
+ return WorkspaceInviteResponse(
+ message=f"Invitation sent to {invite.email} to join workspace '{workspace_db.workspace_name}'.",
+ email=invite.email,
+ workspace_name=workspace_db.workspace_name,
+ user_exists=False,
+ )
+
+ except WorkspaceNotFoundError as e:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail=f"Workspace '{invite.workspace_name}' not found.",
+ ) from e
+ except Exception as e:
+ logger.error(f"Error inviting user to workspace: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail="Error inviting user to workspace.",
+ ) from e
diff --git a/backend/app/workspaces/schemas.py b/backend/app/workspaces/schemas.py
new file mode 100644
index 0000000..5df8378
--- /dev/null
+++ b/backend/app/workspaces/schemas.py
@@ -0,0 +1,111 @@
+from datetime import datetime
+from enum import Enum
+from typing import Optional, List
+
+from pydantic import BaseModel, ConfigDict, EmailStr
+
+
+class UserRoles(str, Enum):
+ """Enumeration for user roles.
+
+ There are 2 different types of users:
+
+ 1. (Read-Only) Users: These users are assigned to workspaces and can only read the
+ contents within their assigned workspaces. They cannot modify existing
+ contents or add new contents to their workspaces, add or delete users from
+ their workspaces, or add or delete workspaces.
+ 2. Admin Users: These users are assigned to workspaces and can read and modify the
+ contents within their assigned workspaces. They can also add or delete users
+ from their own workspaces and can also add new workspaces or delete their own
+ workspaces. Admin users have no control over workspaces that they are not
+ assigned to.
+ """
+
+ ADMIN = "admin"
+ READ_ONLY = "read_only"
+
+
+class WorkspaceCreate(BaseModel):
+ """Pydantic model for workspace creation."""
+ api_daily_quota: int | None = None
+ content_quota: int | None = None
+ workspace_name: str
+
+ model_config = ConfigDict(from_attributes=True)
+
+
+class WorkspaceKeyResponse(BaseModel):
+ """Pydantic model for updating workspace API key."""
+ new_api_key: str
+ workspace_name: str
+
+ model_config = ConfigDict(from_attributes=True)
+
+
+class WorkspaceRetrieve(BaseModel):
+ """Pydantic model for workspace retrieval."""
+ api_daily_quota: Optional[int] = None
+ api_key_first_characters: Optional[str] = None
+ api_key_updated_datetime_utc: Optional[datetime] = None
+ content_quota: Optional[int] = None
+ created_datetime_utc: datetime
+ updated_datetime_utc: Optional[datetime] = None
+ workspace_id: int
+ workspace_name: str
+ is_default: bool = False
+
+ model_config = ConfigDict(from_attributes=True)
+
+
+class WorkspaceSwitch(BaseModel):
+ """Pydantic model for switching workspaces."""
+ workspace_name: str
+
+ model_config = ConfigDict(from_attributes=True)
+
+
+class WorkspaceUpdate(BaseModel):
+ """Pydantic model for workspace updates."""
+ workspace_name: Optional[str] = None
+
+ model_config = ConfigDict(from_attributes=True)
+
+
+class UserWorkspace(BaseModel):
+ """Pydantic model for user workspace information."""
+ user_role: UserRoles
+ workspace_id: int
+ workspace_name: str
+ is_default: bool = False
+
+ model_config = ConfigDict(from_attributes=True)
+
+
+class UserCreateWithCode(BaseModel):
+ """Pydantic model for user creation with recovery codes."""
+ is_default_workspace: bool = False
+ recovery_codes: List[str] = []
+ role: UserRoles
+ username: str
+ workspace_name: str
+
+ model_config = ConfigDict(from_attributes=True)
+
+
+class WorkspaceInvite(BaseModel):
+ """Pydantic model for inviting users to a workspace."""
+ email: EmailStr
+ role: UserRoles = UserRoles.READ_ONLY
+ workspace_name: str
+
+ model_config = ConfigDict(from_attributes=True)
+
+
+class WorkspaceInviteResponse(BaseModel):
+ """Pydantic model for workspace invite response."""
+ message: str
+ email: EmailStr
+ workspace_name: str
+ user_exists: bool
+
+ model_config = ConfigDict(from_attributes=True)
diff --git a/backend/app/workspaces/utils.py b/backend/app/workspaces/utils.py
new file mode 100644
index 0000000..9ae7d83
--- /dev/null
+++ b/backend/app/workspaces/utils.py
@@ -0,0 +1,141 @@
+from datetime import datetime, timezone
+from typing import Optional
+
+from sqlalchemy import select
+from sqlalchemy.exc import NoResultFound
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from ..users.schemas import UserCreate
+from ..utils import get_key_hash
+from .models import WorkspaceDB
+from .schemas import WorkspaceUpdate
+
+
+class WorkspaceNotFoundError(Exception):
+ """Exception raised when a workspace is not found in the database."""
+
+
+async def check_if_workspaces_exist(*, asession: AsyncSession) -> bool:
+ """Check if workspaces exist in the `WorkspaceDB` database."""
+ stmt = select(WorkspaceDB.workspace_id).limit(1)
+ result = await asession.scalars(stmt)
+ return result.first() is not None
+
+
+async def create_workspace(
+ *,
+ api_daily_quota: Optional[int] = None,
+ asession: AsyncSession,
+ content_quota: Optional[int] = None,
+ user: UserCreate,
+ is_default: bool = False,
+ api_key: Optional[str] = None,
+) -> tuple[WorkspaceDB, bool]:
+ """Create a workspace in the `WorkspaceDB` database. If the workspace already
+ exists, then it is returned.
+ """
+ assert api_daily_quota is None or api_daily_quota >= 0
+ assert content_quota is None or content_quota >= 0
+
+ result = await asession.execute(
+ select(WorkspaceDB).where(WorkspaceDB.workspace_name == user.workspace_name)
+ )
+ workspace_db = result.scalar_one_or_none()
+ new_workspace = False
+
+ if workspace_db is None:
+ new_workspace = True
+ workspace_db = WorkspaceDB(
+ api_daily_quota=api_daily_quota,
+ content_quota=content_quota,
+ created_datetime_utc=datetime.now(timezone.utc),
+ updated_datetime_utc=datetime.now(timezone.utc),
+ workspace_name=user.workspace_name,
+ is_default=is_default
+ )
+
+ if api_key:
+ workspace_db.hashed_api_key = get_key_hash(api_key)
+ workspace_db.api_key_first_characters = api_key[:5]
+ workspace_db.api_key_updated_datetime_utc = datetime.now(timezone.utc)
+
+ asession.add(workspace_db)
+ await asession.commit()
+ await asession.refresh(workspace_db)
+
+ return workspace_db, new_workspace
+
+
+async def get_workspace_by_workspace_id(
+ *, asession: AsyncSession, workspace_id: int
+) -> WorkspaceDB:
+ """Retrieve a workspace by workspace ID."""
+ stmt = select(WorkspaceDB).where(WorkspaceDB.workspace_id == workspace_id)
+ result = await asession.execute(stmt)
+ try:
+ workspace_db = result.scalar_one()
+ return workspace_db
+ except NoResultFound as err:
+ raise WorkspaceNotFoundError(
+ f"Workspace with ID {workspace_id} does not exist."
+ ) from err
+
+
+async def get_workspace_by_workspace_name(
+ *, asession: AsyncSession, workspace_name: str
+) -> WorkspaceDB:
+ """Retrieve a workspace by workspace name."""
+ stmt = select(WorkspaceDB).where(WorkspaceDB.workspace_name == workspace_name)
+ result = await asession.execute(stmt)
+ try:
+ workspace_db = result.scalar_one()
+ return workspace_db
+ except NoResultFound as err:
+ raise WorkspaceNotFoundError(
+ f"Workspace with name {workspace_name} does not exist."
+ ) from err
+
+
+async def is_workspace_name_valid(
+ *, asession: AsyncSession, workspace_name: str
+) -> bool:
+ """Check if a workspace name is valid. A workspace name is valid if it doesn't
+ already exist in the database.
+ """
+ stmt = select(WorkspaceDB).where(WorkspaceDB.workspace_name == workspace_name)
+ result = await asession.execute(stmt)
+ try:
+ result.scalar_one()
+ return False
+ except NoResultFound:
+ return True
+
+
+async def update_workspace_api_key(
+ *, asession: AsyncSession, new_api_key: str, workspace_db: WorkspaceDB
+) -> WorkspaceDB:
+ """Update a workspace API key."""
+ workspace_db.hashed_api_key = get_key_hash(key=new_api_key)
+ workspace_db.api_key_first_characters = new_api_key[:5]
+ workspace_db.api_key_updated_datetime_utc = datetime.now(timezone.utc)
+ workspace_db.updated_datetime_utc = datetime.now(timezone.utc)
+
+ await asession.commit()
+ await asession.refresh(workspace_db)
+
+ return workspace_db
+
+
+async def update_workspace_name_and_quotas(
+ *, asession: AsyncSession, workspace: WorkspaceUpdate, workspace_db: WorkspaceDB
+) -> WorkspaceDB:
+ """Update workspace name"""
+ if workspace.workspace_name is not None:
+ workspace_db.workspace_name = workspace.workspace_name
+
+ workspace_db.updated_datetime_utc = datetime.now(timezone.utc)
+
+ await asession.commit()
+ await asession.refresh(workspace_db)
+
+ return workspace_db
diff --git a/backend/migrations/versions/949c9fc0461d_workspace_relationship.py b/backend/migrations/versions/949c9fc0461d_workspace_relationship.py
new file mode 100644
index 0000000..020cb27
--- /dev/null
+++ b/backend/migrations/versions/949c9fc0461d_workspace_relationship.py
@@ -0,0 +1,32 @@
+"""Workspace relationship
+
+Revision ID: 949c9fc0461d
+Revises: 977e7e73ce06
+Create Date: 2025-04-21 21:20:56.282928
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = '949c9fc0461d'
+down_revision: Union[str, None] = '977e7e73ce06'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.add_column('experiments_base', sa.Column('workspace_id', sa.Integer(), nullable=False))
+ op.create_foreign_key(None, 'experiments_base', 'workspace', ['workspace_id'], ['workspace_id'])
+ # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_constraint(None, 'experiments_base', type_='foreignkey')
+ op.drop_column('experiments_base', 'workspace_id')
+ # ### end Alembic commands ###
diff --git a/backend/migrations/versions/977e7e73ce06_workspace_model.py b/backend/migrations/versions/977e7e73ce06_workspace_model.py
new file mode 100644
index 0000000..81ff8f5
--- /dev/null
+++ b/backend/migrations/versions/977e7e73ce06_workspace_model.py
@@ -0,0 +1,56 @@
+"""Workspace model
+
+Revision ID: 977e7e73ce06
+Revises: ba1bf29910f5
+Create Date: 2025-04-20 20:17:32.839934
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = '977e7e73ce06'
+down_revision: Union[str, None] = 'ba1bf29910f5'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.create_table('workspace',
+ sa.Column('api_daily_quota', sa.Integer(), nullable=True),
+ sa.Column('api_key_first_characters', sa.String(length=5), nullable=True),
+ sa.Column('api_key_updated_datetime_utc', sa.DateTime(timezone=True), nullable=True),
+ sa.Column('content_quota', sa.Integer(), nullable=True),
+ sa.Column('created_datetime_utc', sa.DateTime(timezone=True), nullable=False),
+ sa.Column('hashed_api_key', sa.String(length=96), nullable=True),
+ sa.Column('updated_datetime_utc', sa.DateTime(timezone=True), nullable=False),
+ sa.Column('workspace_id', sa.Integer(), nullable=False),
+ sa.Column('workspace_name', sa.String(), nullable=False),
+ sa.Column('is_default', sa.Boolean(), nullable=False),
+ sa.PrimaryKeyConstraint('workspace_id'),
+ sa.UniqueConstraint('hashed_api_key'),
+ sa.UniqueConstraint('workspace_name')
+ )
+ op.create_table('user_workspace',
+ sa.Column('created_datetime_utc', sa.DateTime(timezone=True), nullable=False),
+ sa.Column('default_workspace', sa.Boolean(), server_default=sa.text('false'), nullable=False),
+ sa.Column('updated_datetime_utc', sa.DateTime(timezone=True), nullable=False),
+ sa.Column('user_id', sa.Integer(), nullable=False),
+ sa.Column('user_role', sa.Enum('ADMIN', 'READ_ONLY', name='userroles', native_enum=False), nullable=False),
+ sa.Column('workspace_id', sa.Integer(), nullable=False),
+ sa.ForeignKeyConstraint(['user_id'], ['users.user_id'], ondelete='CASCADE'),
+ sa.ForeignKeyConstraint(['workspace_id'], ['workspace.workspace_id'], ondelete='CASCADE'),
+ sa.PrimaryKeyConstraint('user_id', 'workspace_id')
+ )
+ # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_table('user_workspace')
+ op.drop_table('workspace')
+ # ### end Alembic commands ###
diff --git a/frontend/src/app/(protected)/workspace/create/page.tsx b/frontend/src/app/(protected)/workspace/create/page.tsx
new file mode 100644
index 0000000..2d57b60
--- /dev/null
+++ b/frontend/src/app/(protected)/workspace/create/page.tsx
@@ -0,0 +1,131 @@
+"use client";
+
+import { zodResolver } from "@hookform/resolvers/zod";
+import { useForm } from "react-hook-form";
+import { z } from "zod";
+import { Button } from "@/components/catalyst/button";
+import { Input } from "@/components/catalyst/input";
+import { useAuth } from "@/utils/auth";
+import { apiCalls } from "@/utils/api";
+import { useToast } from "@/hooks/use-toast";
+import { useRouter } from "next/navigation";
+import { useState } from "react";
+import {
+ Fieldset,
+ Field,
+ FieldGroup,
+ Label,
+ Description,
+} from "@/components/catalyst/fieldset";
+import { Card, CardContent, CardHeader, CardTitle, CardDescription } from "@/components/ui/card";
+import { BuildingOfficeIcon } from "@heroicons/react/20/solid";
+
+const formSchema = z.object({
+ workspace_name: z.string().min(3, {
+ message: "Workspace name must be at least 3 characters",
+ }),
+});
+
+type FormValues = z.infer;
+
+export default function CreateWorkspacePage() {
+ const { token, switchWorkspace } = useAuth();
+ const { toast } = useToast();
+ const router = useRouter();
+ const [isSubmitting, setIsSubmitting] = useState(false);
+
+ const form = useForm({
+ resolver: zodResolver(formSchema),
+ defaultValues: {
+ workspace_name: "",
+ },
+ });
+
+ const onSubmit = async (data: FormValues) => {
+ if (!token) {
+ toast({
+ title: "Error",
+ description: "You must be logged in to create a workspace",
+ variant: "destructive",
+ });
+ return;
+ }
+
+ setIsSubmitting(true);
+ try {
+ const response = await apiCalls.createWorkspace(token, data);
+
+ await switchWorkspace(response.workspace_name);
+
+ toast({
+ title: "Success",
+ description: `Workspace "${response.workspace_name}" created and activated!`,
+ });
+
+ router.push("/workspace");
+ } catch (error: any) {
+ toast({
+ title: "Error",
+ description: error.message || "Failed to create workspace",
+ variant: "destructive",
+ });
+ } finally {
+ setIsSubmitting(false);
+ }
+ };
+
+ return (
+
+
+
+
+
+
+ Create new workspace
+
+ Create a new workspace to organize your experiments and team members
+
+
+
+
+
+
+
+
+
+ );
+}
diff --git a/frontend/src/app/(protected)/workspace/invite/page.tsx b/frontend/src/app/(protected)/workspace/invite/page.tsx
new file mode 100644
index 0000000..f610ffc
--- /dev/null
+++ b/frontend/src/app/(protected)/workspace/invite/page.tsx
@@ -0,0 +1,193 @@
+"use client";
+
+import { useState } from "react";
+import { useAuth } from "@/utils/auth";
+import { apiCalls } from "@/utils/api";
+import { useToast } from "@/hooks/use-toast";
+import { z } from "zod";
+import { zodResolver } from "@hookform/resolvers/zod";
+import { useForm } from "react-hook-form";
+import { Button } from "@/components/catalyst/button";
+import { Heading } from "@/components/catalyst/heading";
+import { Input } from "@/components/catalyst/input";
+import {
+ Card,
+ CardContent,
+ CardDescription,
+ CardFooter,
+ CardHeader,
+ CardTitle,
+} from "@/components/ui/card";
+import {
+ Fieldset,
+ Field,
+ FieldGroup,
+ Label,
+ Description,
+} from "@/components/catalyst/fieldset";
+import { Radio, RadioField, RadioGroup } from "@/components/catalyst/radio";
+import { Badge } from "@/components/ui/badge";
+import { EnvelopeIcon, UserPlusIcon } from "@heroicons/react/20/solid";
+
+const inviteSchema = z.object({
+ email: z.string().email({
+ message: "Please enter a valid email address",
+ }),
+ role: z.enum(["ADMIN", "EDITOR", "VIEWER"], {
+ required_error: "Please select a role",
+ }),
+});
+
+type InviteFormValues = z.infer;
+
+export default function InviteUsersPage() {
+ const { token, currentWorkspace } = useAuth();
+ const { toast } = useToast();
+ const [isSubmitting, setIsSubmitting] = useState(false);
+ const [invitedUsers, setInvitedUsers] = useState<{ email: string; role: string; exists: boolean }[]>([]);
+
+ const {
+ register,
+ handleSubmit,
+ reset,
+ formState: { errors },
+ setValue,
+ watch,
+ } = useForm({
+ resolver: zodResolver(inviteSchema),
+ defaultValues: {
+ email: "",
+ role: "VIEWER",
+ },
+ });
+
+ const roleValue = watch("role");
+
+ const onSubmit = async (data: InviteFormValues) => {
+ if (!token || !currentWorkspace) {
+ toast({
+ title: "Error",
+ description: "You must be logged in and have a workspace selected",
+ variant: "destructive",
+ });
+ return;
+ }
+
+ setIsSubmitting(true);
+ try {
+ const response = await apiCalls.inviteUserToWorkspace(token, {
+ email: data.email,
+ role: data.role,
+ workspace_name: currentWorkspace.workspace_name,
+ });
+
+ // Add to invited users list
+ setInvitedUsers([
+ ...invitedUsers,
+ {
+ email: data.email,
+ role: data.role,
+ exists: response.user_exists,
+ },
+ ]);
+
+ // Reset form
+ reset();
+
+ toast({
+ title: "Success",
+ description: `Invitation sent to ${data.email}`,
+ });
+ } catch (error: any) {
+ toast({
+ title: "Error",
+ description: error.message || "Failed to send invitation",
+ variant: "destructive",
+ });
+ } finally {
+ setIsSubmitting(false);
+ }
+ };
+
+ return (
+
+
+ Invite Team Members
+
+
+
+
+
+
+ Send Invitation
+
+ Invite users to join your workspace: {currentWorkspace?.workspace_name}
+
+
+
+
+
+
+
+
+
+ );
+}
diff --git a/frontend/src/app/(protected)/workspace/page.tsx b/frontend/src/app/(protected)/workspace/page.tsx
new file mode 100644
index 0000000..daceb87
--- /dev/null
+++ b/frontend/src/app/(protected)/workspace/page.tsx
@@ -0,0 +1,116 @@
+"use client";
+
+import { useEffect } from "react";
+import { useRouter } from "next/navigation";
+import { useAuth } from "@/utils/auth";
+import { Heading } from "@/components/catalyst/heading";
+import { Card, CardContent, CardHeader, CardTitle, CardDescription } from "@/components/ui/card";
+import { Plus, Users, Settings, Key } from "lucide-react";
+import { Button } from "@/components/catalyst/button";
+
+export default function WorkspacePage() {
+ const { currentWorkspace } = useAuth();
+ const router = useRouter();
+
+ if (!currentWorkspace) {
+ return (
+
+
+
+
+ Something went wrong. Please try again later.
+
+
+
+
+ );
+ }
+ console.log("Current Workspace:", currentWorkspace);
+
+ return (
+
+
+ {currentWorkspace.workspace_name}
+
+
+
+
+
+ Workspace Information
+
+
+
+
+
Name:
+ {currentWorkspace.workspace_name}
+
+
+
API Quota:
+ {currentWorkspace.api_daily_quota.toLocaleString()} calls/day
+
+
+
Experiment Quota:
+ {currentWorkspace.content_quota.toLocaleString()} experiments
+
+
+
Created:
+
+ {new Date(currentWorkspace.created_datetime_utc).toLocaleDateString()}
+
+
+
+
+
+
+
+
+ API Key
+
+
+
+
+
+
+ {currentWorkspace.api_key_first_characters}
+ {"•".repeat(27)}
+
+
+
router.push('/integration')}>
+ Manage API Keys
+
+
+
+ Use this API key to authenticate your API requests. Keep it secret and secure.
+
+
+
+
+
+
+
router.push('/workspace/invite')}>
+
+
+ Invite Team Members
+
+ Invite colleagues to join your workspace
+
+
+
+
+
+
+
router.push('/workspace/create')}>
+
+
+ Create New Workspace
+
+ Create a new workspace for different projects
+
+
+
+
+
+
+
+ );
+}
\ No newline at end of file
diff --git a/frontend/src/app/(protected)/workspace/types.ts b/frontend/src/app/(protected)/workspace/types.ts
new file mode 100644
index 0000000..65c4108
--- /dev/null
+++ b/frontend/src/app/(protected)/workspace/types.ts
@@ -0,0 +1,51 @@
+export enum UserRoles {
+ ADMIN = "ADMIN",
+ EDITOR = "EDITOR",
+ VIEWER = "VIEWER",
+}
+
+export interface Workspace {
+ workspace_id: number;
+ workspace_name: string;
+ api_key_first_characters: string;
+ api_key_updated_datetime_utc: string;
+ api_daily_quota: number;
+ content_quota: number;
+ created_datetime_utc: string;
+ updated_datetime_utc: string;
+ is_default: boolean;
+}
+
+export interface WorkspaceCreate {
+ workspace_name: string;
+ api_daily_quota?: number;
+ content_quota?: number;
+}
+
+export interface WorkspaceUpdate {
+ workspace_name?: string;
+ api_daily_quota?: number;
+ content_quota?: number;
+}
+
+export interface WorkspaceKeyResponse {
+ new_api_key: string;
+ workspace_name: string;
+}
+
+export interface WorkspaceInvite {
+ email: string;
+ role: UserRoles;
+ workspace_name: string;
+}
+
+export interface WorkspaceInviteResponse {
+ message: string;
+ email: string;
+ workspace_name: string;
+ user_exists: boolean;
+}
+
+export interface WorkspaceSwitch {
+ workspace_name: string;
+}
diff --git a/frontend/src/components/WorkspaceSelector.tsx b/frontend/src/components/WorkspaceSelector.tsx
new file mode 100644
index 0000000..c958ea1
--- /dev/null
+++ b/frontend/src/components/WorkspaceSelector.tsx
@@ -0,0 +1,142 @@
+"use client";
+
+import React, { useState } from "react";
+import { useAuth } from "@/utils/auth";
+import { Button } from "@/components/catalyst/button";
+import {
+ Dialog,
+ DialogActions,
+ DialogBody,
+ DialogDescription,
+ DialogTitle,
+} from "@/components/catalyst/dialog";
+import { BuildingOfficeIcon, ChevronUpDownIcon, PlusIcon } from "@heroicons/react/20/solid";
+import {
+ DropdownItem,
+ DropdownLabel,
+ DropdownMenu,
+ DropdownButton,
+ DropdownDivider,
+ Dropdown,
+} from "@/components/catalyst/dropdown";
+import { useToast } from "@/hooks/use-toast";
+import { useRouter } from "next/navigation";
+
+export default function WorkspaceSelector() {
+ const { currentWorkspace, workspaces, switchWorkspace, isLoading } = useAuth();
+ const [isOpen, setIsOpen] = useState(false);
+ const { toast } = useToast();
+ const router = useRouter();
+
+ const handleSwitchWorkspace = async (workspaceName: string) => {
+ try {
+ await switchWorkspace(workspaceName);
+ toast({
+ title: "Workspace Changed",
+ description: `Switched to workspace: ${workspaceName}`,
+ });
+ } catch (error) {
+ toast({
+ title: "Error",
+ description: "Failed to switch workspace",
+ variant: "destructive",
+ });
+ }
+ };
+
+ const handleCreateWorkspace = () => {
+ router.push("/workspace/create");
+ };
+
+ if (isLoading || !currentWorkspace) {
+ return (
+
+
+
+ Loading...
+
+
+
+ );
+ }
+
+ return (
+
+
+
+
+
+ {currentWorkspace.workspace_name}
+
+
+
+
+
+ {workspaces.map((workspace) => (
+ handleSwitchWorkspace(workspace.workspace_name)}
+ >
+
+ {workspace.workspace_name}
+
+ ))}
+
+
+
+
+
+ Create New Workspace
+
+
+
+
+
setIsOpen(false)}>
+ Switch Workspace
+
+ Select a workspace to switch to
+
+
+
+ {workspaces.map((workspace) => (
+
{
+ handleSwitchWorkspace(workspace.workspace_name);
+ setIsOpen(false);
+ }}
+ >
+
{workspace.workspace_name}
+ {workspace.workspace_id === currentWorkspace.workspace_id && (
+
Current
+ )}
+
+ ))}
+
+
+
+ setIsOpen(false)}>
+ Cancel
+
+
+
+ Create New
+
+
+
+
+ );
+}
\ No newline at end of file
diff --git a/frontend/src/components/sidebar.tsx b/frontend/src/components/sidebar.tsx
index fe2d64f..482b0a6 100644
--- a/frontend/src/components/sidebar.tsx
+++ b/frontend/src/components/sidebar.tsx
@@ -10,6 +10,7 @@ import {
SidebarLabel,
SidebarSection,
SidebarSpacer,
+ SidebarDivider,
} from "@/components/catalyst/sidebar";
import { Avatar } from "@/components/catalyst/avatar";
import { Dropdown, DropdownButton } from "@/components/catalyst/dropdown";
@@ -22,11 +23,14 @@ import {
MagnifyingGlassIcon,
QuestionMarkCircleIcon,
SparklesIcon,
+ BuildingOfficeIcon,
+ UserPlusIcon,
} from "@heroicons/react/20/solid";
import { useAuth } from "@/utils/auth";
+import WorkspaceSelector from "./WorkspaceSelector";
export const SidebarComponent = (): React.ReactNode => {
- const { user } = useAuth();
+ const { user, currentWorkspace } = useAuth();
return (
@@ -43,6 +47,19 @@ export const SidebarComponent = (): React.ReactNode => {
+
+ Workspace
+
+
+
+ Manage Workspace
+
+
+
+ Invite Members
+
+
+
{navItems.map((item) => (
@@ -51,7 +68,7 @@ export const SidebarComponent = (): React.ReactNode => {
))}
-
+ {/*
New experiments
Modifying voice of chatbot
@@ -65,7 +82,7 @@ export const SidebarComponent = (): React.ReactNode => {
When to send the message
-
+ */}
diff --git a/frontend/src/utils/api.ts b/frontend/src/utils/api.ts
index 4c8d0df..23ca704 100644
--- a/frontend/src/utils/api.ts
+++ b/frontend/src/utils/api.ts
@@ -90,42 +90,97 @@ const registerUser = async (username: string, password: string) => {
}
};
-const requestPasswordReset = async (username: string) => {
+const getCurrentWorkspace = async (token: string | null) => {
try {
- const response = await api.post("/request-password-reset", { username });
+ const response = await api.get("/workspace/current", {
+ headers: {
+ Authorization: `Bearer ${token}`,
+ },
+ });
return response.data;
} catch (error) {
- throw new Error("Error requesting password reset");
+ throw new Error("Error fetching current workspace");
}
};
-const resetPassword = async (token: string, newPassword: string) => {
+const getAllWorkspaces = async (token: string | null) => {
try {
- const response = await api.post("/reset-password", {
- token,
- new_password: newPassword
+ const response = await api.get("/workspace/", {
+ headers: {
+ Authorization: `Bearer ${token}`,
+ },
});
return response.data;
} catch (error) {
- throw new Error("Error resetting password");
+ throw new Error("Error fetching workspaces");
}
};
-const verifyEmail = async (token: string) => {
+const createWorkspace = async (token: string | null, workspaceData: any) => {
try {
- const response = await api.post("/verify-email", { token });
+ const response = await api.post("/workspace/", workspaceData, {
+ headers: {
+ Authorization: `Bearer ${token}`,
+ },
+ });
return response.data;
} catch (error) {
- throw new Error("Error verifying email");
+ throw new Error("Error creating workspace");
}
};
-const resendVerification = async (username: string) => {
+const updateWorkspace = async (token: string | null, workspaceId: number, workspaceData: any) => {
try {
- const response = await api.post("/resend-verification", { username });
+ const response = await api.put(`/workspace/${workspaceId}`, workspaceData, {
+ headers: {
+ Authorization: `Bearer ${token}`,
+ },
+ });
return response.data;
} catch (error) {
- throw new Error("Error resending verification email");
+ throw new Error("Error updating workspace");
+ }
+};
+
+const switchWorkspace = async (token: string | null, workspaceName: string) => {
+ try {
+ const response = await api.post("/workspace/switch",
+ { workspace_name: workspaceName },
+ {
+ headers: {
+ Authorization: `Bearer ${token}`,
+ },
+ }
+ );
+ return response.data;
+ } catch (error) {
+ throw new Error("Error switching workspace");
+ }
+};
+
+const rotateWorkspaceApiKey = async (token: string | null) => {
+ try {
+ const response = await api.put("/workspace/rotate-key", {}, {
+ headers: {
+ Authorization: `Bearer ${token}`,
+ },
+ });
+ return response.data;
+ } catch (error) {
+ throw new Error("Error rotating workspace API key");
+ }
+};
+
+const inviteUserToWorkspace = async (token: string | null, inviteData: any) => {
+ try {
+ const response = await api.post("/workspace/invite", inviteData, {
+ headers: {
+ Authorization: `Bearer ${token}`,
+ },
+ });
+ return response.data;
+ } catch (error) {
+ throw new Error("Error inviting user to workspace");
}
};
@@ -177,5 +232,12 @@ export const apiCalls = {
resetPassword,
verifyEmail,
resendVerification,
+ getCurrentWorkspace,
+ getAllWorkspaces,
+ createWorkspace,
+ updateWorkspace,
+ switchWorkspace,
+ rotateWorkspaceApiKey,
+ inviteUserToWorkspace,
};
export default api;
diff --git a/frontend/src/utils/auth.tsx b/frontend/src/utils/auth.tsx
index ee07bad..814cf98 100644
--- a/frontend/src/utils/auth.tsx
+++ b/frontend/src/utils/auth.tsx
@@ -3,14 +3,24 @@ import { apiCalls } from "@/utils/api";
import { useRouter, useSearchParams } from "next/navigation";
import { ReactNode, createContext, useContext, useState, useEffect } from "react";
+type Workspace = {
+ workspace_id: number;
+ workspace_name: string;
+ api_key_first_characters: string;
+ is_default: boolean;
+};
+
type AuthContextType = {
token: string | null;
user: string | null;
isVerified: boolean;
isLoading: boolean;
- login: (username: string, password: string) => void;
+ currentWorkspace: Workspace | null;
+ workspaces: Workspace[];
+ login: (username: string, password: string) => Promise;
logout: () => void;
loginError: string | null;
+ switchWorkspace: (workspaceName: string) => Promise;
loginGoogle: ({
client_id,
credential,
@@ -46,10 +56,35 @@ const AuthProvider = ({ children }: AuthProviderProps) => {
const [isVerified, setIsVerified] = useState(false);
const [isLoading, setIsLoading] = useState(!!getInitialToken());
const [loginError, setLoginError] = useState(null);
+ const [currentWorkspace, setCurrentWorkspace] = useState(null);
+ const [workspaces, setWorkspaces] = useState([]);
const searchParams = useSearchParams();
const router = useRouter();
+ useEffect(() => {
+ const loadWorkspaceInfo = async () => {
+ if (token) {
+ try {
+ setIsLoading(true);
+ // Fetch current workspace
+ const currentWorkspaceData = await apiCalls.getCurrentWorkspace(token);
+ setCurrentWorkspace(currentWorkspaceData);
+
+ // Fetch all workspaces
+ const workspacesData = await apiCalls.getAllWorkspaces(token);
+ setWorkspaces(workspacesData);
+ } catch (error) {
+ console.error("Error loading workspace info:", error);
+ } finally {
+ setIsLoading(false);
+ }
+ }
+ };
+
+ loadWorkspaceInfo();
+ }, [token]);
+
// Check verification status on init if token exists
useEffect(() => {
const checkVerificationStatus = async () => {
@@ -101,6 +136,16 @@ const AuthProvider = ({ children }: AuthProviderProps) => {
}
}
+ try {
+ const currentWorkspaceData = await apiCalls.getCurrentWorkspace(access_token);
+ setCurrentWorkspace(currentWorkspaceData);
+
+ const workspacesData = await apiCalls.getAllWorkspaces(access_token);
+ setWorkspaces(workspacesData);
+ } catch (error) {
+ console.error("Error loading workspace info:", error);
+ }
+
// Redirect to verification page if not verified, otherwise to original destination
if (response.is_verified === false) {
router.push("/verification-required");
@@ -123,6 +168,26 @@ const AuthProvider = ({ children }: AuthProviderProps) => {
}
};
+ const switchWorkspace = async (workspaceName: string) => {
+ try {
+ setIsLoading(true);
+ const response = await apiCalls.switchWorkspace(token, workspaceName);
+
+ localStorage.setItem("ee-token", response.access_token);
+ setToken(response.access_token);
+
+ const currentWorkspaceData = await apiCalls.getCurrentWorkspace(response.access_token);
+ setCurrentWorkspace(currentWorkspaceData);
+
+ return response;
+ } catch (error) {
+ console.error("Error switching workspace:", error);
+ throw error;
+ } finally {
+ setIsLoading(false);
+ }
+ };
+
const loginGoogle = async ({
client_id,
credential,
@@ -151,6 +216,16 @@ const AuthProvider = ({ children }: AuthProviderProps) => {
setIsVerified(true);
+ try {
+ const currentWorkspaceData = await apiCalls.getCurrentWorkspace(access_token);
+ setCurrentWorkspace(currentWorkspaceData);
+
+ const workspacesData = await apiCalls.getAllWorkspaces(access_token);
+ setWorkspaces(workspacesData);
+ } catch (error) {
+ console.error("Error loading workspace info:", error);
+ }
+
router.push(sourcePage);
} catch (error) {
setLoginError("Invalid Google credentials");
@@ -167,6 +242,8 @@ const AuthProvider = ({ children }: AuthProviderProps) => {
setUser(null);
setToken(null);
setIsVerified(false);
+ setCurrentWorkspace(null);
+ setWorkspaces([]);
router.push("/login");
};
@@ -175,9 +252,12 @@ const AuthProvider = ({ children }: AuthProviderProps) => {
user,
isVerified,
isLoading,
+ currentWorkspace,
+ workspaces,
login,
loginError,
loginGoogle,
+ switchWorkspace,
logout,
};
From 219aea776ed42d8e1c20422c82690e6d26e4ad52 Mon Sep 17 00:00:00 2001
From: Jay Prakash <0freerunning@gmail.com>
Date: Tue, 22 Apr 2025 03:40:04 +0530
Subject: [PATCH 05/74] File formating
---
backend/app/__init__.py | 2 +-
backend/app/auth/dependencies.py | 2 +-
backend/app/auth/routers.py | 19 ++---
backend/app/contextual_mab/models.py | 10 +--
backend/app/contextual_mab/routers.py | 68 ++++++-----------
backend/app/email.py | 13 +++-
backend/app/mab/models.py | 10 +--
backend/app/mab/routers.py | 72 +++++++-----------
backend/app/models.py | 6 +-
backend/app/users/models.py | 3 +-
backend/app/users/routers.py | 3 +-
backend/app/workspaces/models.py | 15 ++--
backend/app/workspaces/routers.py | 66 ++++++++--------
backend/app/workspaces/schemas.py | 11 ++-
backend/app/workspaces/utils.py | 6 +-
.../949c9fc0461d_workspace_relationship.py | 20 +++--
.../versions/977e7e73ce06_workspace_model.py | 75 +++++++++++--------
.../app/(protected)/workspace/create/page.tsx | 10 +--
.../src/app/(protected)/workspace/page.tsx | 2 +-
.../src/app/(protected)/workspace/types.ts | 1 -
frontend/src/components/WorkspaceSelector.tsx | 6 +-
frontend/src/utils/api.ts | 2 +-
frontend/src/utils/auth.tsx | 12 +--
23 files changed, 204 insertions(+), 230 deletions(-)
diff --git a/backend/app/__init__.py b/backend/app/__init__.py
index 2653972..0f97151 100644
--- a/backend/app/__init__.py
+++ b/backend/app/__init__.py
@@ -10,8 +10,8 @@
from .users.routers import (
router as users_router,
) # to avoid circular imports
-from .workspaces.routers import router as workspaces_router
from .utils import setup_logger
+from .workspaces.routers import router as workspaces_router
logger = setup_logger()
diff --git a/backend/app/auth/dependencies.py b/backend/app/auth/dependencies.py
index d34a593..3fcaefd 100644
--- a/backend/app/auth/dependencies.py
+++ b/backend/app/auth/dependencies.py
@@ -198,7 +198,7 @@ def create_access_token(username: str, workspace_name: str = None) -> str:
payload["iat"] = datetime.now(timezone.utc)
payload["sub"] = username
payload["type"] = "access_token"
-
+
if workspace_name:
payload["workspace_name"] = workspace_name
diff --git a/backend/app/auth/routers.py b/backend/app/auth/routers.py
index c989ae8..04346d9 100644
--- a/backend/app/auth/routers.py
+++ b/backend/app/auth/routers.py
@@ -7,7 +7,6 @@
from sqlalchemy.ext.asyncio import AsyncSession
from ..config import DEFAULT_API_QUOTA, DEFAULT_EXPERIMENTS_QUOTA
-
from ..database import get_async_session, get_redis
from ..email import EmailService
from ..users.models import (
@@ -101,12 +100,12 @@ async def login_google(
# Import here to avoid circular imports
from ..workspaces.models import (
- create_user_workspace_role,
+ UserRoles,
+ create_user_workspace_role,
get_user_default_workspace,
- UserRoles
)
from ..workspaces.utils import create_workspace
-
+
user_email = idinfo["email"]
user = await authenticate_or_create_google_user(
request=request, google_email=user_email, asession=asession
@@ -118,15 +117,17 @@ async def login_google(
)
user_db = await get_user_by_username(username=user_email, asession=asession)
-
+
# Create default workspace if user is new (has no workspaces)
try:
- default_workspace = await get_user_default_workspace(asession=asession, user_db=user_db)
+ default_workspace = await get_user_default_workspace(
+ asession=asession, user_db=user_db
+ )
default_workspace_name = default_workspace.workspace_name
except Exception:
# User doesn't have a default workspace, create one
default_workspace_name = f"{user_email}'s Workspace"
-
+
# Create default workspace
workspace_db, _ = await create_workspace(
api_daily_quota=DEFAULT_API_QUOTA,
@@ -137,9 +138,9 @@ async def login_google(
username=user_email,
workspace_name=default_workspace_name,
),
- is_default=True
+ is_default=True,
)
-
+
await create_user_workspace_role(
asession=asession,
is_default_workspace=True,
diff --git a/backend/app/contextual_mab/models.py b/backend/app/contextual_mab/models.py
index da4d699..f88eaa4 100644
--- a/backend/app/contextual_mab/models.py
+++ b/backend/app/contextual_mab/models.py
@@ -265,10 +265,7 @@ async def get_all_contextual_mabs(
async def get_contextual_mab_by_id(
- experiment_id: int,
- user_id: int,
- workspace_id: int,
- asession: AsyncSession
+ experiment_id: int, user_id: int, workspace_id: int, asession: AsyncSession
) -> ContextualBanditDB | None:
"""
Get the contextual experiment by id.
@@ -284,10 +281,7 @@ async def get_contextual_mab_by_id(
async def delete_contextual_mab_by_id(
- experiment_id: int,
- user_id: int,
- workspace_id: int,
- asession: AsyncSession
+ experiment_id: int, user_id: int, workspace_id: int, asession: AsyncSession
) -> None:
"""
Delete the contextual experiment by id.
diff --git a/backend/app/contextual_mab/routers.py b/backend/app/contextual_mab/routers.py
index c56aa32..e9484bf 100644
--- a/backend/app/contextual_mab/routers.py
+++ b/backend/app/contextual_mab/routers.py
@@ -4,14 +4,13 @@
from fastapi.exceptions import HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
-from ..workspaces.models import get_user_default_workspace, get_user_role_in_workspace
-from ..workspaces.schemas import UserRoles
-
from ..auth.dependencies import authenticate_key, get_verified_user
from ..database import get_async_session
from ..models import get_notifications_from_db, save_notifications_to_db
from ..schemas import ContextType, NotificationsResponse, Outcome
from ..users.models import UserDB
+from ..workspaces.models import get_user_default_workspace, get_user_role_in_workspace
+from ..workspaces.schemas import UserRoles
from .models import (
delete_contextual_mab_by_id,
get_all_contextual_mabs,
@@ -45,22 +44,19 @@ async def create_contextual_mabs(
Create a new contextual experiment with different priors for each context.
"""
workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db)
-
+
user_role = await get_user_role_in_workspace(
asession=asession, user_db=user_db, workspace_db=workspace_db
)
-
+
if user_role != UserRoles.ADMIN:
raise HTTPException(
status_code=403,
detail="Only workspace administrators can create experiments.",
)
-
+
cmab = await save_contextual_mab_to_db(
- experiment,
- user_db.user_id,
- workspace_db.workspace_id,
- asession
+ experiment, user_db.user_id, workspace_db.workspace_id, asession
)
notifications = await save_notifications_to_db(
experiment_id=cmab.experiment_id,
@@ -82,11 +78,9 @@ async def get_contextual_mabs(
Get details of all experiments.
"""
workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db)
-
+
experiments = await get_all_contextual_mabs(
- user_db.user_id,
- workspace_db.workspace_id,
- asession
+ user_db.user_id, workspace_db.workspace_id, asession
)
all_experiments = []
for exp in experiments:
@@ -124,10 +118,7 @@ async def get_contextual_mab(
workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db)
experiment = await get_contextual_mab_by_id(
- experiment_id,
- user_db.user_id,
- workspace_db.workspace_id,
- asession
+ experiment_id, user_db.user_id, workspace_db.workspace_id, asession
)
if experiment is None:
raise HTTPException(
@@ -154,33 +145,29 @@ async def delete_contextual_mab(
Delete the experiment with the provided `experiment_id`.
"""
try:
- workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db)
-
+ workspace_db = await get_user_default_workspace(
+ asession=asession, user_db=user_db
+ )
+
user_role = await get_user_role_in_workspace(
asession=asession, user_db=user_db, workspace_db=workspace_db
)
-
+
if user_role != UserRoles.ADMIN:
raise HTTPException(
status_code=403,
detail="Only workspace administrators can delete experiments.",
)
-
+
experiment = await get_contextual_mab_by_id(
- experiment_id,
- user_db.user_id,
- workspace_db.workspace_id,
- asession
+ experiment_id, user_db.user_id, workspace_db.workspace_id, asession
)
if experiment is None:
raise HTTPException(
status_code=404, detail=f"Experiment with id {experiment_id} not found"
)
await delete_contextual_mab_by_id(
- experiment_id,
- user_db.user_id,
- workspace_db.workspace_id,
- asession
+ experiment_id, user_db.user_id, workspace_db.workspace_id, asession
)
return {"detail": f"Experiment {experiment_id} deleted successfully."}
except Exception as e:
@@ -198,12 +185,9 @@ async def draw_arm(
Get which arm to pull next for provided experiment.
"""
workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db)
-
+
experiment = await get_contextual_mab_by_id(
- experiment_id,
- user_db.user_id,
- workspace_db.workspace_id,
- asession
+ experiment_id, user_db.user_id, workspace_db.workspace_id, asession
)
if experiment is None:
@@ -247,13 +231,10 @@ async def update_arm(
`experiment_id` based on the `outcome`.
"""
workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db)
-
+
# Get the experiment and do checks
experiment = await get_contextual_mab_by_id(
- experiment_id,
- user_db.user_id,
- workspace_db.workspace_id,
- asession
+ experiment_id, user_db.user_id, workspace_db.workspace_id, asession
)
if experiment is None:
raise HTTPException(
@@ -346,12 +327,9 @@ async def get_outcomes(
Get the outcomes for the experiment.
"""
workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db)
-
+
experiment = await get_contextual_mab_by_id(
- experiment_id,
- user_db.user_id,
- workspace_db.workspace_id,
- asession
+ experiment_id, user_db.user_id, workspace_db.workspace_id, asession
)
if not experiment:
raise HTTPException(
diff --git a/backend/app/email.py b/backend/app/email.py
index 9770090..d75d90a 100644
--- a/backend/app/email.py
+++ b/backend/app/email.py
@@ -125,7 +125,12 @@ async def send_password_reset_email(
return await self._send_email(email, subject, html_body, text_body)
async def send_workspace_invitation_email(
- self, email: str, username: str, inviter_email: str, workspace_name: str, user_exists: bool
+ self,
+ email: str,
+ username: str,
+ inviter_email: str,
+ workspace_name: str,
+ user_exists: bool,
) -> Dict[str, Any]:
"""
Send workspace invitation email
@@ -150,7 +155,7 @@ async def send_workspace_invitation_email(
Hello {username},
You have been invited by {inviter_email} to join the workspace "{workspace_name}".
-
+
You have been added to this workspace. Log in to access it.
{FRONTEND_URL}/login
@@ -175,12 +180,12 @@ async def send_workspace_invitation_email(
Hello,
You have been invited by {inviter_email} to join the workspace "{workspace_name}".
-
+
You need to create an account to join this workspace.
{FRONTEND_URL}/signup
"""
-
+
return await self._send_email(email, subject, html_body, text_body)
async def _send_email(
diff --git a/backend/app/mab/models.py b/backend/app/mab/models.py
index 377498d..416193a 100644
--- a/backend/app/mab/models.py
+++ b/backend/app/mab/models.py
@@ -198,10 +198,7 @@ async def get_all_mabs(
async def get_mab_by_id(
- experiment_id: int,
- user_id: int,
- workspace_id: int,
- asession: AsyncSession
+ experiment_id: int, user_id: int, workspace_id: int, asession: AsyncSession
) -> MultiArmedBanditDB | None:
"""
Get the experiment by id.
@@ -217,10 +214,7 @@ async def get_mab_by_id(
async def delete_mab_by_id(
- experiment_id: int,
- user_id: int,
- workspace_id: int,
- asession: AsyncSession
+ experiment_id: int, user_id: int, workspace_id: int, asession: AsyncSession
) -> None:
"""
Delete the experiment by id.
diff --git a/backend/app/mab/routers.py b/backend/app/mab/routers.py
index fd5af38..f700d50 100644
--- a/backend/app/mab/routers.py
+++ b/backend/app/mab/routers.py
@@ -5,14 +5,13 @@
from fastapi.exceptions import HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
-from ..workspaces.models import get_user_default_workspace, get_user_role_in_workspace
-from ..workspaces.schemas import UserRoles
-
from ..auth.dependencies import authenticate_key, get_verified_user
from ..database import get_async_session
from ..models import get_notifications_from_db, save_notifications_to_db
from ..schemas import NotificationsResponse, Outcome, RewardLikelihood
from ..users.models import UserDB
+from ..workspaces.models import get_user_default_workspace, get_user_role_in_workspace
+from ..workspaces.schemas import UserRoles
from .models import (
delete_mab_by_id,
get_all_mabs,
@@ -44,24 +43,21 @@ async def create_mab(
Create a new experiment in the user's current workspace.
"""
workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db)
-
+
user_role = await get_user_role_in_workspace(
asession=asession, user_db=user_db, workspace_db=workspace_db
)
-
+
if user_role != UserRoles.ADMIN:
raise HTTPException(
status_code=403,
detail="Only workspace administrators can create experiments.",
)
-
+
mab = await save_mab_to_db(
- experiment,
- user_db.user_id,
- workspace_db.workspace_id,
- asession
+ experiment, user_db.user_id, workspace_db.workspace_id, asession
)
-
+
notifications = await save_notifications_to_db(
experiment_id=mab.experiment_id,
user_id=user_db.user_id,
@@ -84,11 +80,9 @@ async def get_mabs(
Get details of all experiments in the user's current workspace.
"""
workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db)
-
+
experiments = await get_all_mabs(
- user_db.user_id,
- workspace_db.workspace_id,
- asession
+ user_db.user_id, workspace_db.workspace_id, asession
)
all_experiments = []
@@ -123,12 +117,9 @@ async def get_mab(
Get details of experiment with the provided `experiment_id`.
"""
workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db)
-
+
experiment = await get_mab_by_id(
- experiment_id,
- user_db.user_id,
- workspace_db.workspace_id,
- asession
+ experiment_id, user_db.user_id, workspace_db.workspace_id, asession
)
if experiment is None:
@@ -157,33 +148,29 @@ async def delete_mab(
Delete the experiment with the provided `experiment_id`.
"""
try:
- workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db)
-
+ workspace_db = await get_user_default_workspace(
+ asession=asession, user_db=user_db
+ )
+
user_role = await get_user_role_in_workspace(
asession=asession, user_db=user_db, workspace_db=workspace_db
)
-
+
if user_role != UserRoles.ADMIN:
raise HTTPException(
status_code=403,
detail="Only workspace administrators can delete experiments.",
)
-
+
experiment = await get_mab_by_id(
- experiment_id,
- user_db.user_id,
- workspace_db.workspace_id,
- asession
+ experiment_id, user_db.user_id, workspace_db.workspace_id, asession
)
if experiment is None:
raise HTTPException(
status_code=404, detail=f"Experiment with id {experiment_id} not found"
)
await delete_mab_by_id(
- experiment_id,
- user_db.user_id,
- workspace_db.workspace_id,
- asession
+ experiment_id, user_db.user_id, workspace_db.workspace_id, asession
)
return {"message": f"Experiment with id {experiment_id} deleted successfully."}
except Exception as e:
@@ -200,12 +187,9 @@ async def draw_arm(
Get which arm to pull next for provided experiment.
"""
workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db)
-
+
experiment = await get_mab_by_id(
- experiment_id,
- user_db.user_id,
- workspace_db.workspace_id,
- asession
+ experiment_id, user_db.user_id, workspace_db.workspace_id, asession
)
if experiment is None:
raise HTTPException(
@@ -229,12 +213,9 @@ async def update_arm(
`experiment_id` based on the `outcome`.
"""
workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db)
-
+
experiment = await get_mab_by_id(
- experiment_id,
- user_db.user_id,
- workspace_db.workspace_id,
- asession
+ experiment_id, user_db.user_id, workspace_db.workspace_id, asession
)
if experiment is None:
raise HTTPException(
@@ -303,12 +284,9 @@ async def get_outcomes(
Get the outcomes for the experiment.
"""
workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db)
-
+
experiment = await get_mab_by_id(
- experiment_id,
- user_db.user_id,
- workspace_db.workspace_id,
- asession
+ experiment_id, user_db.user_id, workspace_db.workspace_id, asession
)
if not experiment:
raise HTTPException(
diff --git a/backend/app/models.py b/backend/app/models.py
index 5314ebb..162461b 100644
--- a/backend/app/models.py
+++ b/backend/app/models.py
@@ -44,8 +44,10 @@ class ExperimentBaseDB(Base):
last_trial_datetime_utc: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=True
)
- workspace: Mapped["WorkspaceDB"] = relationship("WorkspaceDB", back_populates="experiments")
-
+ workspace: Mapped["WorkspaceDB"] = relationship(
+ "WorkspaceDB", back_populates="experiments"
+ )
+
__mapper_args__ = {
"polymorphic_identity": "experiment",
"polymorphic_on": "exp_type",
diff --git a/backend/app/users/models.py b/backend/app/users/models.py
index 6f8bfc2..9d68acd 100644
--- a/backend/app/users/models.py
+++ b/backend/app/users/models.py
@@ -11,10 +11,9 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Mapped, mapped_column, relationship
-from ..workspaces.models import UserWorkspaceDB, WorkspaceDB
-
from ..models import Base
from ..utils import get_key_hash, get_password_salted_hash, get_random_string
+from ..workspaces.models import UserWorkspaceDB, WorkspaceDB
from .schemas import UserCreate, UserCreateWithPassword
PASSWORD_LENGTH = 12
diff --git a/backend/app/users/routers.py b/backend/app/users/routers.py
index 90844bb..dee1776 100644
--- a/backend/app/users/routers.py
+++ b/backend/app/users/routers.py
@@ -6,10 +6,9 @@
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
-from ..config import DEFAULT_API_QUOTA, DEFAULT_EXPERIMENTS_QUOTA
-
from ..auth.dependencies import get_current_user, get_verified_user
from ..auth.utils import generate_verification_token
+from ..config import DEFAULT_API_QUOTA, DEFAULT_EXPERIMENTS_QUOTA
from ..database import get_async_session, get_redis
from ..email import EmailService
from ..users.models import (
diff --git a/backend/app/workspaces/models.py b/backend/app/workspaces/models.py
index f7b37a2..9547917 100644
--- a/backend/app/workspaces/models.py
+++ b/backend/app/workspaces/models.py
@@ -1,14 +1,12 @@
from datetime import datetime, timezone
-from typing import Sequence, TYPE_CHECKING
+from typing import TYPE_CHECKING, Sequence
-import sqlalchemy.sql.functions as func
from sqlalchemy import (
Boolean,
DateTime,
Enum,
ForeignKey,
Integer,
- Row,
String,
case,
exists,
@@ -16,15 +14,12 @@
text,
update,
)
-from sqlalchemy.exc import NoResultFound
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Mapped, mapped_column, relationship
-from ..users.schemas import UserCreate
-
from ..models import Base, ExperimentBaseDB
-from ..utils import get_key_hash
-from .schemas import UserRoles, UserCreateWithCode
+from ..users.schemas import UserCreate
+from .schemas import UserCreateWithCode, UserRoles
# Use TYPE_CHECKING for type hints without runtime imports
if TYPE_CHECKING:
@@ -193,7 +188,7 @@ async def update_user_default_workspace(
(UserWorkspaceDB.workspace_id == workspace_db.workspace_id, True),
else_=False,
),
- updated_datetime_utc=datetime.now(timezone.utc)
+ updated_datetime_utc=datetime.now(timezone.utc),
)
)
@@ -275,7 +270,7 @@ async def add_existing_user_to_workspace(
"""Add an existing user to a workspace."""
# Import here to avoid circular imports
from ..users.models import get_user_by_username
-
+
assert user.role is not None
user.is_default_workspace = user.is_default_workspace or False
diff --git a/backend/app/workspaces/routers.py b/backend/app/workspaces/routers.py
index d36fb1a..2cb8800 100644
--- a/backend/app/workspaces/routers.py
+++ b/backend/app/workspaces/routers.py
@@ -2,9 +2,9 @@
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, status
from fastapi.exceptions import HTTPException
+from redis.asyncio import Redis
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
-from redis.asyncio import Redis
from ..auth.dependencies import (
create_access_token,
@@ -17,7 +17,6 @@
from ..users.models import (
UserDB,
UserNotFoundError,
- save_user_to_db,
get_user_by_username,
)
from ..users.schemas import UserCreate
@@ -28,9 +27,7 @@
get_user_default_workspace,
get_user_role_in_workspace,
get_user_workspaces,
- get_workspaces_by_user_role,
update_user_default_workspace,
- user_has_admin_role_in_any_workspace,
)
from .schemas import (
UserRoles,
@@ -113,7 +110,7 @@ async def create_workspace_endpoint(
),
workspace_db=workspace_db,
)
-
+
return WorkspaceRetrieve(
api_daily_quota=workspace_db.api_daily_quota,
api_key_first_characters=workspace_db.api_key_first_characters,
@@ -138,8 +135,10 @@ async def retrieve_all_workspaces(
asession: AsyncSession = Depends(get_async_session),
) -> List[WorkspaceRetrieve]:
"""Return a list of all workspaces the user has access to."""
- user_workspaces = await get_user_workspaces(asession=asession, user_db=calling_user_db)
-
+ user_workspaces = await get_user_workspaces(
+ asession=asession, user_db=calling_user_db
+ )
+
return [
WorkspaceRetrieve(
api_daily_quota=workspace_db.api_daily_quota,
@@ -166,7 +165,7 @@ async def get_current_workspace(
workspace_db = await get_user_default_workspace(
asession=asession, user_db=calling_user_db
)
-
+
return WorkspaceRetrieve(
api_daily_quota=workspace_db.api_daily_quota,
api_key_first_characters=workspace_db.api_key_first_characters,
@@ -207,7 +206,7 @@ async def switch_workspace(
user_role = await get_user_role_in_workspace(
asession=asession, user_db=calling_user_db, workspace_db=workspace_db
)
-
+
if user_role is None:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
@@ -241,25 +240,25 @@ async def rotate_workspace_api_key(
workspace_db = await get_user_default_workspace(
asession=asession, user_db=calling_user_db
)
-
+
# Verify user is an admin in this workspace
user_role = await get_user_role_in_workspace(
asession=asession, user_db=calling_user_db, workspace_db=workspace_db
)
-
+
if user_role != UserRoles.ADMIN:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only workspace administrators can rotate API keys.",
)
-
+
# Generate and update the API key
new_api_key = generate_key()
asession.add(workspace_db)
workspace_db = await update_workspace_api_key(
asession=asession, new_api_key=new_api_key, workspace_db=workspace_db
)
-
+
return WorkspaceKeyResponse(
new_api_key=new_api_key,
workspace_name=workspace_db.workspace_name,
@@ -284,18 +283,18 @@ async def retrieve_workspace_by_workspace_id(
workspace_db = await get_workspace_by_workspace_id(
asession=asession, workspace_id=workspace_id
)
-
+
# Check if user has access to this workspace
user_role = await get_user_role_in_workspace(
asession=asession, user_db=calling_user_db, workspace_db=workspace_db
)
-
+
if user_role is None:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"User does not have access to workspace with ID {workspace_id}.",
)
-
+
return WorkspaceRetrieve(
api_daily_quota=workspace_db.api_daily_quota,
api_key_first_characters=workspace_db.api_key_first_characters,
@@ -327,20 +326,23 @@ async def update_workspace_endpoint(
workspace_db = await get_workspace_by_workspace_id(
asession=asession, workspace_id=workspace_id
)
-
+
# Verify user is an admin in this workspace
user_role = await get_user_role_in_workspace(
asession=asession, user_db=calling_user_db, workspace_db=workspace_db
)
-
+
if user_role != UserRoles.ADMIN:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only workspace administrators can update workspace details.",
)
-
+
# Check if the new workspace name is valid
- if workspace_update.workspace_name and workspace_update.workspace_name != workspace_db.workspace_name:
+ if (
+ workspace_update.workspace_name
+ and workspace_update.workspace_name != workspace_db.workspace_name
+ ):
if not await is_workspace_name_valid(
asession=asession, workspace_name=workspace_update.workspace_name
):
@@ -348,13 +350,13 @@ async def update_workspace_endpoint(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Workspace with name '{workspace_update.workspace_name}' already exists.",
)
-
+
# Update the workspace
asession.add(workspace_db)
updated_workspace = await update_workspace_name_and_quotas(
asession=asession, workspace=workspace_update, workspace_db=workspace_db
)
-
+
return WorkspaceRetrieve(
api_daily_quota=updated_workspace.api_daily_quota,
api_key_first_characters=updated_workspace.api_key_first_characters,
@@ -393,25 +395,25 @@ async def invite_user_to_workspace(
workspace_db = await get_workspace_by_workspace_name(
asession=asession, workspace_name=invite.workspace_name
)
-
+
# Check if it's a default workspace (users can't invite others to default workspaces)
if workspace_db.is_default:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Users cannot be invited to default workspaces.",
)
-
+
# Verify caller is an admin in this workspace
user_role = await get_user_role_in_workspace(
asession=asession, user_db=calling_user_db, workspace_db=workspace_db
)
-
+
if user_role != UserRoles.ADMIN:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only workspace administrators can invite users.",
)
-
+
# Check if the invited user exists
user_exists = False
try:
@@ -419,7 +421,7 @@ async def invite_user_to_workspace(
username=invite.email, asession=asession
)
user_exists = True
-
+
# Add existing user to workspace
await add_existing_user_to_workspace(
asession=asession,
@@ -430,7 +432,7 @@ async def invite_user_to_workspace(
),
workspace_db=workspace_db,
)
-
+
# Send invitation email to existing user
background_tasks.add_task(
email_service.send_workspace_invitation_email,
@@ -440,14 +442,14 @@ async def invite_user_to_workspace(
workspace_db.workspace_name,
True, # user exists
)
-
+
return WorkspaceInviteResponse(
message=f"User {invite.email} has been added to workspace '{workspace_db.workspace_name}'.",
email=invite.email,
workspace_name=workspace_db.workspace_name,
user_exists=True,
)
-
+
except UserNotFoundError:
# User doesn't exist, send invitation to create account
background_tasks.add_task(
@@ -458,14 +460,14 @@ async def invite_user_to_workspace(
workspace_db.workspace_name,
False, # user doesn't exist
)
-
+
return WorkspaceInviteResponse(
message=f"Invitation sent to {invite.email} to join workspace '{workspace_db.workspace_name}'.",
email=invite.email,
workspace_name=workspace_db.workspace_name,
user_exists=False,
)
-
+
except WorkspaceNotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
diff --git a/backend/app/workspaces/schemas.py b/backend/app/workspaces/schemas.py
index 5df8378..512a78b 100644
--- a/backend/app/workspaces/schemas.py
+++ b/backend/app/workspaces/schemas.py
@@ -1,6 +1,6 @@
from datetime import datetime
from enum import Enum
-from typing import Optional, List
+from typing import List, Optional
from pydantic import BaseModel, ConfigDict, EmailStr
@@ -27,6 +27,7 @@ class UserRoles(str, Enum):
class WorkspaceCreate(BaseModel):
"""Pydantic model for workspace creation."""
+
api_daily_quota: int | None = None
content_quota: int | None = None
workspace_name: str
@@ -36,6 +37,7 @@ class WorkspaceCreate(BaseModel):
class WorkspaceKeyResponse(BaseModel):
"""Pydantic model for updating workspace API key."""
+
new_api_key: str
workspace_name: str
@@ -44,6 +46,7 @@ class WorkspaceKeyResponse(BaseModel):
class WorkspaceRetrieve(BaseModel):
"""Pydantic model for workspace retrieval."""
+
api_daily_quota: Optional[int] = None
api_key_first_characters: Optional[str] = None
api_key_updated_datetime_utc: Optional[datetime] = None
@@ -59,6 +62,7 @@ class WorkspaceRetrieve(BaseModel):
class WorkspaceSwitch(BaseModel):
"""Pydantic model for switching workspaces."""
+
workspace_name: str
model_config = ConfigDict(from_attributes=True)
@@ -66,6 +70,7 @@ class WorkspaceSwitch(BaseModel):
class WorkspaceUpdate(BaseModel):
"""Pydantic model for workspace updates."""
+
workspace_name: Optional[str] = None
model_config = ConfigDict(from_attributes=True)
@@ -73,6 +78,7 @@ class WorkspaceUpdate(BaseModel):
class UserWorkspace(BaseModel):
"""Pydantic model for user workspace information."""
+
user_role: UserRoles
workspace_id: int
workspace_name: str
@@ -83,6 +89,7 @@ class UserWorkspace(BaseModel):
class UserCreateWithCode(BaseModel):
"""Pydantic model for user creation with recovery codes."""
+
is_default_workspace: bool = False
recovery_codes: List[str] = []
role: UserRoles
@@ -94,6 +101,7 @@ class UserCreateWithCode(BaseModel):
class WorkspaceInvite(BaseModel):
"""Pydantic model for inviting users to a workspace."""
+
email: EmailStr
role: UserRoles = UserRoles.READ_ONLY
workspace_name: str
@@ -103,6 +111,7 @@ class WorkspaceInvite(BaseModel):
class WorkspaceInviteResponse(BaseModel):
"""Pydantic model for workspace invite response."""
+
message: str
email: EmailStr
workspace_name: str
diff --git a/backend/app/workspaces/utils.py b/backend/app/workspaces/utils.py
index 9ae7d83..4897303 100644
--- a/backend/app/workspaces/utils.py
+++ b/backend/app/workspaces/utils.py
@@ -42,7 +42,7 @@ async def create_workspace(
)
workspace_db = result.scalar_one_or_none()
new_workspace = False
-
+
if workspace_db is None:
new_workspace = True
workspace_db = WorkspaceDB(
@@ -51,9 +51,9 @@ async def create_workspace(
created_datetime_utc=datetime.now(timezone.utc),
updated_datetime_utc=datetime.now(timezone.utc),
workspace_name=user.workspace_name,
- is_default=is_default
+ is_default=is_default,
)
-
+
if api_key:
workspace_db.hashed_api_key = get_key_hash(api_key)
workspace_db.api_key_first_characters = api_key[:5]
diff --git a/backend/migrations/versions/949c9fc0461d_workspace_relationship.py b/backend/migrations/versions/949c9fc0461d_workspace_relationship.py
index 020cb27..d04edf2 100644
--- a/backend/migrations/versions/949c9fc0461d_workspace_relationship.py
+++ b/backend/migrations/versions/949c9fc0461d_workspace_relationship.py
@@ -5,28 +5,32 @@
Create Date: 2025-04-21 21:20:56.282928
"""
+
from typing import Sequence, Union
-from alembic import op
import sqlalchemy as sa
-
+from alembic import op
# revision identifiers, used by Alembic.
-revision: str = '949c9fc0461d'
-down_revision: Union[str, None] = '977e7e73ce06'
+revision: str = "949c9fc0461d"
+down_revision: Union[str, None] = "977e7e73ce06"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
- op.add_column('experiments_base', sa.Column('workspace_id', sa.Integer(), nullable=False))
- op.create_foreign_key(None, 'experiments_base', 'workspace', ['workspace_id'], ['workspace_id'])
+ op.add_column(
+ "experiments_base", sa.Column("workspace_id", sa.Integer(), nullable=False)
+ )
+ op.create_foreign_key(
+ None, "experiments_base", "workspace", ["workspace_id"], ["workspace_id"]
+ )
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
- op.drop_constraint(None, 'experiments_base', type_='foreignkey')
- op.drop_column('experiments_base', 'workspace_id')
+ op.drop_constraint(None, "experiments_base", type_="foreignkey")
+ op.drop_column("experiments_base", "workspace_id")
# ### end Alembic commands ###
diff --git a/backend/migrations/versions/977e7e73ce06_workspace_model.py b/backend/migrations/versions/977e7e73ce06_workspace_model.py
index 81ff8f5..bb4ff38 100644
--- a/backend/migrations/versions/977e7e73ce06_workspace_model.py
+++ b/backend/migrations/versions/977e7e73ce06_workspace_model.py
@@ -5,52 +5,67 @@
Create Date: 2025-04-20 20:17:32.839934
"""
+
from typing import Sequence, Union
-from alembic import op
import sqlalchemy as sa
-
+from alembic import op
# revision identifiers, used by Alembic.
-revision: str = '977e7e73ce06'
-down_revision: Union[str, None] = 'ba1bf29910f5'
+revision: str = "977e7e73ce06"
+down_revision: Union[str, None] = "ba1bf29910f5"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
- op.create_table('workspace',
- sa.Column('api_daily_quota', sa.Integer(), nullable=True),
- sa.Column('api_key_first_characters', sa.String(length=5), nullable=True),
- sa.Column('api_key_updated_datetime_utc', sa.DateTime(timezone=True), nullable=True),
- sa.Column('content_quota', sa.Integer(), nullable=True),
- sa.Column('created_datetime_utc', sa.DateTime(timezone=True), nullable=False),
- sa.Column('hashed_api_key', sa.String(length=96), nullable=True),
- sa.Column('updated_datetime_utc', sa.DateTime(timezone=True), nullable=False),
- sa.Column('workspace_id', sa.Integer(), nullable=False),
- sa.Column('workspace_name', sa.String(), nullable=False),
- sa.Column('is_default', sa.Boolean(), nullable=False),
- sa.PrimaryKeyConstraint('workspace_id'),
- sa.UniqueConstraint('hashed_api_key'),
- sa.UniqueConstraint('workspace_name')
+ op.create_table(
+ "workspace",
+ sa.Column("api_daily_quota", sa.Integer(), nullable=True),
+ sa.Column("api_key_first_characters", sa.String(length=5), nullable=True),
+ sa.Column(
+ "api_key_updated_datetime_utc", sa.DateTime(timezone=True), nullable=True
+ ),
+ sa.Column("content_quota", sa.Integer(), nullable=True),
+ sa.Column("created_datetime_utc", sa.DateTime(timezone=True), nullable=False),
+ sa.Column("hashed_api_key", sa.String(length=96), nullable=True),
+ sa.Column("updated_datetime_utc", sa.DateTime(timezone=True), nullable=False),
+ sa.Column("workspace_id", sa.Integer(), nullable=False),
+ sa.Column("workspace_name", sa.String(), nullable=False),
+ sa.Column("is_default", sa.Boolean(), nullable=False),
+ sa.PrimaryKeyConstraint("workspace_id"),
+ sa.UniqueConstraint("hashed_api_key"),
+ sa.UniqueConstraint("workspace_name"),
)
- op.create_table('user_workspace',
- sa.Column('created_datetime_utc', sa.DateTime(timezone=True), nullable=False),
- sa.Column('default_workspace', sa.Boolean(), server_default=sa.text('false'), nullable=False),
- sa.Column('updated_datetime_utc', sa.DateTime(timezone=True), nullable=False),
- sa.Column('user_id', sa.Integer(), nullable=False),
- sa.Column('user_role', sa.Enum('ADMIN', 'READ_ONLY', name='userroles', native_enum=False), nullable=False),
- sa.Column('workspace_id', sa.Integer(), nullable=False),
- sa.ForeignKeyConstraint(['user_id'], ['users.user_id'], ondelete='CASCADE'),
- sa.ForeignKeyConstraint(['workspace_id'], ['workspace.workspace_id'], ondelete='CASCADE'),
- sa.PrimaryKeyConstraint('user_id', 'workspace_id')
+ op.create_table(
+ "user_workspace",
+ sa.Column("created_datetime_utc", sa.DateTime(timezone=True), nullable=False),
+ sa.Column(
+ "default_workspace",
+ sa.Boolean(),
+ server_default=sa.text("false"),
+ nullable=False,
+ ),
+ sa.Column("updated_datetime_utc", sa.DateTime(timezone=True), nullable=False),
+ sa.Column("user_id", sa.Integer(), nullable=False),
+ sa.Column(
+ "user_role",
+ sa.Enum("ADMIN", "READ_ONLY", name="userroles", native_enum=False),
+ nullable=False,
+ ),
+ sa.Column("workspace_id", sa.Integer(), nullable=False),
+ sa.ForeignKeyConstraint(["user_id"], ["users.user_id"], ondelete="CASCADE"),
+ sa.ForeignKeyConstraint(
+ ["workspace_id"], ["workspace.workspace_id"], ondelete="CASCADE"
+ ),
+ sa.PrimaryKeyConstraint("user_id", "workspace_id"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
- op.drop_table('user_workspace')
- op.drop_table('workspace')
+ op.drop_table("user_workspace")
+ op.drop_table("workspace")
# ### end Alembic commands ###
diff --git a/frontend/src/app/(protected)/workspace/create/page.tsx b/frontend/src/app/(protected)/workspace/create/page.tsx
index 2d57b60..99aff9e 100644
--- a/frontend/src/app/(protected)/workspace/create/page.tsx
+++ b/frontend/src/app/(protected)/workspace/create/page.tsx
@@ -54,14 +54,14 @@ export default function CreateWorkspacePage() {
setIsSubmitting(true);
try {
const response = await apiCalls.createWorkspace(token, data);
-
+
await switchWorkspace(response.workspace_name);
-
+
toast({
title: "Success",
description: `Workspace "${response.workspace_name}" created and activated!`,
});
-
+
router.push("/workspace");
} catch (error: any) {
toast({
@@ -73,7 +73,7 @@ export default function CreateWorkspacePage() {
setIsSubmitting(false);
}
};
-
+
return (
@@ -106,7 +106,7 @@ export default function CreateWorkspacePage() {
Choose a descriptive name for your new workspace.
-
+
);
-}
\ No newline at end of file
+}
diff --git a/frontend/src/app/(protected)/workspace/types.ts b/frontend/src/app/(protected)/workspace/types.ts
index 65c4108..2d53c31 100644
--- a/frontend/src/app/(protected)/workspace/types.ts
+++ b/frontend/src/app/(protected)/workspace/types.ts
@@ -1,6 +1,5 @@
export enum UserRoles {
ADMIN = "ADMIN",
- EDITOR = "EDITOR",
VIEWER = "VIEWER",
}
diff --git a/frontend/src/components/WorkspaceSelector.tsx b/frontend/src/components/WorkspaceSelector.tsx
index c958ea1..4692888 100644
--- a/frontend/src/components/WorkspaceSelector.tsx
+++ b/frontend/src/components/WorkspaceSelector.tsx
@@ -89,9 +89,9 @@ export default function WorkspaceSelector() {
{workspace.workspace_name}
))}
-
+
-
+
Create New Workspace
@@ -139,4 +139,4 @@ export default function WorkspaceSelector() {
);
-}
\ No newline at end of file
+}
diff --git a/frontend/src/utils/api.ts b/frontend/src/utils/api.ts
index 23ca704..41edbdd 100644
--- a/frontend/src/utils/api.ts
+++ b/frontend/src/utils/api.ts
@@ -144,7 +144,7 @@ const updateWorkspace = async (token: string | null, workspaceId: number, worksp
const switchWorkspace = async (token: string | null, workspaceName: string) => {
try {
- const response = await api.post("/workspace/switch",
+ const response = await api.post("/workspace/switch",
{ workspace_name: workspaceName },
{
headers: {
diff --git a/frontend/src/utils/auth.tsx b/frontend/src/utils/auth.tsx
index 814cf98..ee6650d 100644
--- a/frontend/src/utils/auth.tsx
+++ b/frontend/src/utils/auth.tsx
@@ -70,7 +70,7 @@ const AuthProvider = ({ children }: AuthProviderProps) => {
// Fetch current workspace
const currentWorkspaceData = await apiCalls.getCurrentWorkspace(token);
setCurrentWorkspace(currentWorkspaceData);
-
+
// Fetch all workspaces
const workspacesData = await apiCalls.getAllWorkspaces(token);
setWorkspaces(workspacesData);
@@ -139,7 +139,7 @@ const AuthProvider = ({ children }: AuthProviderProps) => {
try {
const currentWorkspaceData = await apiCalls.getCurrentWorkspace(access_token);
setCurrentWorkspace(currentWorkspaceData);
-
+
const workspacesData = await apiCalls.getAllWorkspaces(access_token);
setWorkspaces(workspacesData);
} catch (error) {
@@ -172,13 +172,13 @@ const AuthProvider = ({ children }: AuthProviderProps) => {
try {
setIsLoading(true);
const response = await apiCalls.switchWorkspace(token, workspaceName);
-
+
localStorage.setItem("ee-token", response.access_token);
setToken(response.access_token);
-
+
const currentWorkspaceData = await apiCalls.getCurrentWorkspace(response.access_token);
setCurrentWorkspace(currentWorkspaceData);
-
+
return response;
} catch (error) {
console.error("Error switching workspace:", error);
@@ -219,7 +219,7 @@ const AuthProvider = ({ children }: AuthProviderProps) => {
try {
const currentWorkspaceData = await apiCalls.getCurrentWorkspace(access_token);
setCurrentWorkspace(currentWorkspaceData);
-
+
const workspacesData = await apiCalls.getAllWorkspaces(access_token);
setWorkspaces(workspacesData);
} catch (error) {
From 43a2c3851473527358d2156111924432c889a056 Mon Sep 17 00:00:00 2001
From: Jay Prakash <0freerunning@gmail.com>
Date: Thu, 1 May 2025 02:04:05 +0530
Subject: [PATCH 06/74] Added workspace removal and user list
---
backend/app/auth/dependencies.py | 2 +-
backend/app/auth/routers.py | 2 +-
backend/app/bayes_ab/models.py | 19 ++-
backend/app/bayes_ab/routers.py | 111 ++++++++++++--
backend/app/bayes_ab/schemas.py | 1 +
backend/app/contextual_mab/routers.py | 2 +-
backend/app/mab/routers.py | 2 +-
backend/app/users/exceptions.py | 6 +
backend/app/users/models.py | 9 +-
backend/app/users/routers.py | 39 ++++-
backend/app/workspaces/models.py | 119 +++++++++++++++
backend/app/workspaces/routers.py | 139 +++++++++++++++++-
backend/app/workspaces/schemas.py | 14 ++
.../949c9fc0461d_workspace_relationship.py | 36 -----
.../versions/977e7e73ce06_workspace_model.py | 71 ---------
.../versions/d9f7a309944e_workspace_model.py | 72 +++++++++
16 files changed, 501 insertions(+), 143 deletions(-)
create mode 100644 backend/app/users/exceptions.py
delete mode 100644 backend/migrations/versions/949c9fc0461d_workspace_relationship.py
delete mode 100644 backend/migrations/versions/977e7e73ce06_workspace_model.py
create mode 100644 backend/migrations/versions/d9f7a309944e_workspace_model.py
diff --git a/backend/app/auth/dependencies.py b/backend/app/auth/dependencies.py
index 63e1d3a..a4d18f3 100644
--- a/backend/app/auth/dependencies.py
+++ b/backend/app/auth/dependencies.py
@@ -16,12 +16,12 @@
from ..database import get_async_session
from ..users.models import (
UserDB,
- UserNotFoundError,
get_user_by_api_key,
get_user_by_username,
save_user_to_db,
update_user_verification_status,
)
+from ..users.exceptions import UserNotFoundError
from ..users.schemas import UserCreate
from ..utils import (
generate_key,
diff --git a/backend/app/auth/routers.py b/backend/app/auth/routers.py
index f9f1934..d447014 100644
--- a/backend/app/auth/routers.py
+++ b/backend/app/auth/routers.py
@@ -10,11 +10,11 @@
from ..database import get_async_session, get_redis
from ..email import EmailService
from ..users.models import (
- UserNotFoundError,
get_user_by_username,
update_user_password,
update_user_verification_status,
)
+from ..users.exceptions import UserNotFoundError
from ..users.schemas import (
EmailVerificationRequest,
MessageResponse,
diff --git a/backend/app/bayes_ab/models.py b/backend/app/bayes_ab/models.py
index adf1745..29c61f7 100644
--- a/backend/app/bayes_ab/models.py
+++ b/backend/app/bayes_ab/models.py
@@ -52,6 +52,7 @@ def to_dict(self) -> dict:
return {
"experiment_id": self.experiment_id,
"user_id": self.user_id,
+ "workspace_id": self.workspace_id,
"name": self.name,
"description": self.description,
"sticky_assignment": self.sticky_assignment,
@@ -156,6 +157,7 @@ def to_dict(self) -> dict:
async def save_bayes_ab_to_db(
ab_experiment: BayesianAB,
user_id: int,
+ workspace_id: int,
asession: AsyncSession,
) -> BayesianABDB:
"""
@@ -180,6 +182,7 @@ async def save_bayes_ab_to_db(
name=ab_experiment.name,
description=ab_experiment.description,
user_id=user_id,
+ workspace_id=workspace_id,
is_active=ab_experiment.is_active,
created_datetime_utc=datetime.now(timezone.utc),
n_trials=0,
@@ -201,14 +204,18 @@ async def save_bayes_ab_to_db(
async def get_all_bayes_ab_experiments(
user_id: int,
+ workspace_id: int,
asession: AsyncSession,
) -> Sequence[BayesianABDB]:
"""
- Get all the A/B experiments from the database.
+ Get all the A/B experiments from the database for a specific workspace.
"""
stmt = (
select(BayesianABDB)
- .where(BayesianABDB.user_id == user_id)
+ .where(
+ BayesianABDB.user_id == user_id,
+ BayesianABDB.workspace_id == workspace_id
+ )
.order_by(BayesianABDB.experiment_id)
)
result = await asession.execute(stmt)
@@ -218,14 +225,16 @@ async def get_all_bayes_ab_experiments(
async def get_bayes_ab_experiment_by_id(
experiment_id: int,
user_id: int,
+ workspace_id: int,
asession: AsyncSession,
) -> BayesianABDB | None:
"""
- Get the A/B experiment by id.
+ Get the A/B experiment by id from a specific workspace.
"""
stmt = select(BayesianABDB).where(
and_(
BayesianABDB.user_id == user_id,
+ BayesianABDB.workspace_id == workspace_id,
BayesianABDB.experiment_id == experiment_id,
)
)
@@ -236,14 +245,16 @@ async def get_bayes_ab_experiment_by_id(
async def delete_bayes_ab_experiment_by_id(
experiment_id: int,
user_id: int,
+ workspace_id: int,
asession: AsyncSession,
) -> None:
"""
- Delete the A/B experiment by id.
+ Delete the A/B experiment by id from a specific workspace.
"""
stmt = delete(BayesianABDB).where(
and_(
BayesianABDB.user_id == user_id,
+ BayesianABDB.workspace_id == workspace_id,
BayesianABDB.experiment_id == experiment_id,
BayesianABDB.experiment_id == ExperimentBaseDB.experiment_id,
)
diff --git a/backend/app/bayes_ab/routers.py b/backend/app/bayes_ab/routers.py
index 5b095e5..bde53e4 100644
--- a/backend/app/bayes_ab/routers.py
+++ b/backend/app/bayes_ab/routers.py
@@ -2,7 +2,7 @@
from typing import Annotated, Optional
from uuid import uuid4
-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Depends, status
from fastapi.exceptions import HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
@@ -11,6 +11,8 @@
from ..models import get_notifications_from_db, save_notifications_to_db
from ..schemas import NotificationsResponse, Outcome, RewardLikelihood
from ..users.models import UserDB
+from ..workspaces.models import get_user_default_workspace, get_user_role_in_workspace
+from ..workspaces.schemas import UserRoles
from .models import (
BayesianABArmDB,
BayesianABDB,
@@ -44,9 +46,27 @@ async def create_ab_experiment(
asession: AsyncSession = Depends(get_async_session),
) -> BayesianABResponse:
"""
- Create a new experiment.
+ Create a new experiment in the user's current workspace.
"""
- bayes_ab = await save_bayes_ab_to_db(experiment, user_db.user_id, asession)
+ workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db)
+
+ user_role = await get_user_role_in_workspace(
+ asession=asession, user_db=user_db, workspace_db=workspace_db
+ )
+
+ if user_role != UserRoles.ADMIN:
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail="Only workspace administrators can create experiments.",
+ )
+
+ bayes_ab = await save_bayes_ab_to_db(
+ experiment,
+ user_db.user_id,
+ workspace_db.workspace_id,
+ asession
+ )
+
notifications = await save_notifications_to_db(
experiment_id=bayes_ab.experiment_id,
user_id=user_db.user_id,
@@ -66,9 +86,15 @@ async def get_bayes_abs(
asession: AsyncSession = Depends(get_async_session),
) -> list[BayesianABResponse]:
"""
- Get details of all experiments.
+ Get details of all experiments in the user's current workspace.
"""
- experiments = await get_all_bayes_ab_experiments(user_db.user_id, asession)
+ workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db)
+
+ experiments = await get_all_bayes_ab_experiments(
+ user_db.user_id,
+ workspace_db.workspace_id,
+ asession
+ )
all_experiments = []
for exp in experiments:
@@ -101,8 +127,13 @@ async def get_bayes_ab(
"""
Get details of experiment with the provided `experiment_id`.
"""
+ workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db)
+
experiment = await get_bayes_ab_experiment_by_id(
- experiment_id, user_db.user_id, asession
+ experiment_id,
+ user_db.user_id,
+ workspace_db.workspace_id,
+ asession
)
if experiment is None:
@@ -131,14 +162,36 @@ async def delete_bayes_ab(
Delete the experiment with the provided `experiment_id`.
"""
try:
+ workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db)
+
+ user_role = await get_user_role_in_workspace(
+ asession=asession, user_db=user_db, workspace_db=workspace_db
+ )
+
+ if user_role != UserRoles.ADMIN:
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail="Only workspace administrators can delete experiments.",
+ )
+
experiment = await get_bayes_ab_experiment_by_id(
- experiment_id, user_db.user_id, asession
+ experiment_id,
+ user_db.user_id,
+ workspace_db.workspace_id,
+ asession
)
if experiment is None:
raise HTTPException(
status_code=404, detail=f"Experiment with id {experiment_id} not found"
)
- await delete_bayes_ab_experiment_by_id(experiment_id, user_db.user_id, asession)
+
+ await delete_bayes_ab_experiment_by_id(
+ experiment_id,
+ user_db.user_id,
+ workspace_db.workspace_id,
+ asession
+ )
+
return {"message": f"Experiment with id {experiment_id} deleted successfully."}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error: {e}") from e
@@ -154,8 +207,13 @@ async def draw_arm(
"""
Get which arm to pull next for provided experiment.
"""
+ workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db)
+
experiment = await get_bayes_ab_experiment_by_id(
- experiment_id, user_db.user_id, asession
+ experiment_id,
+ user_db.user_id,
+ workspace_db.workspace_id,
+ asession
)
if experiment is None:
@@ -214,9 +272,15 @@ async def save_observation_for_arm(
Update the arm with the provided `arm_id` for the given
`experiment_id` based on the `outcome`.
"""
+ workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db)
+
# Get and validate experiment
experiment, draw = await validate_experiment_and_draw(
- experiment_id, draw_id, user_db.user_id, asession
+ experiment_id=experiment_id,
+ draw_id=draw_id,
+ user_id=user_db.user_id,
+ workspace_id=workspace_db.workspace_id,
+ asession=asession,
)
update_experiment_metadata(experiment)
@@ -246,8 +310,13 @@ async def get_outcomes(
"""
Get the outcomes for the experiment.
"""
+ workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db)
+
experiment = await get_bayes_ab_experiment_by_id(
- experiment_id, user_db.user_id, asession
+ experiment_id,
+ user_db.user_id,
+ workspace_db.workspace_id,
+ asession
)
if not experiment:
raise HTTPException(
@@ -275,9 +344,14 @@ async def update_arms(
"""
Get the outcomes for the experiment.
"""
+ workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db)
+
# Check experiment params
experiment = await get_bayes_ab_experiment_by_id(
- experiment_id, user_db.user_id, asession
+ experiment_id,
+ user_db.user_id,
+ workspace_db.workspace_id,
+ asession
)
if not experiment:
raise HTTPException(
@@ -312,10 +386,19 @@ async def update_arms(
async def validate_experiment_and_draw(
- experiment_id: int, draw_id: str, user_id: int, asession: AsyncSession
+ experiment_id: int,
+ draw_id: str,
+ user_id: int,
+ workspace_id: int,
+ asession: AsyncSession,
) -> tuple[BayesianABDB, BayesianABDrawDB]:
"""Validate the experiment and draw"""
- experiment = await get_bayes_ab_experiment_by_id(experiment_id, user_id, asession)
+ experiment = await get_bayes_ab_experiment_by_id(
+ experiment_id,
+ user_id,
+ workspace_id,
+ asession
+ )
if experiment is None:
raise HTTPException(
status_code=404, detail=f"Experiment with id {experiment_id} not found"
diff --git a/backend/app/bayes_ab/schemas.py b/backend/app/bayes_ab/schemas.py
index 35183fa..b0dba26 100644
--- a/backend/app/bayes_ab/schemas.py
+++ b/backend/app/bayes_ab/schemas.py
@@ -108,6 +108,7 @@ class BayesianABResponse(MultiArmedBanditBase):
"""
experiment_id: int
+ workspace_id: int
arms: list[BayesABArmResponse]
notifications: list[NotificationsResponse]
created_datetime_utc: datetime
diff --git a/backend/app/contextual_mab/routers.py b/backend/app/contextual_mab/routers.py
index 48a9fa2..46f67d7 100644
--- a/backend/app/contextual_mab/routers.py
+++ b/backend/app/contextual_mab/routers.py
@@ -12,9 +12,9 @@
from ..models import get_notifications_from_db, save_notifications_to_db
from ..schemas import ContextType, NotificationsResponse, Outcome, RewardLikelihood
from ..users.models import UserDB
+from ..utils import setup_logger
from ..workspaces.models import get_user_default_workspace, get_user_role_in_workspace
from ..workspaces.schemas import UserRoles
-from ..utils import setup_logger
from .models import (
ContextualArmDB,
ContextualBanditDB,
diff --git a/backend/app/mab/routers.py b/backend/app/mab/routers.py
index 2adf5a2..6adf12a 100644
--- a/backend/app/mab/routers.py
+++ b/backend/app/mab/routers.py
@@ -11,9 +11,9 @@
from ..models import get_notifications_from_db, save_notifications_to_db
from ..schemas import NotificationsResponse, Outcome, RewardLikelihood
from ..users.models import UserDB
+from ..utils import setup_logger
from ..workspaces.models import get_user_default_workspace, get_user_role_in_workspace
from ..workspaces.schemas import UserRoles
-from ..utils import setup_logger
from .models import (
MABArmDB,
MABDrawDB,
diff --git a/backend/app/users/exceptions.py b/backend/app/users/exceptions.py
new file mode 100644
index 0000000..8be0490
--- /dev/null
+++ b/backend/app/users/exceptions.py
@@ -0,0 +1,6 @@
+class UserNotFoundError(Exception):
+ """Exception raised when a user is not found in the database."""
+
+
+class UserAlreadyExistsError(Exception):
+ """Exception raised when a user already exists in the database."""
\ No newline at end of file
diff --git a/backend/app/users/models.py b/backend/app/users/models.py
index 2981195..c5912b9 100644
--- a/backend/app/users/models.py
+++ b/backend/app/users/models.py
@@ -15,18 +15,11 @@
from ..utils import get_key_hash, get_password_salted_hash, get_random_string
from ..workspaces.models import UserWorkspaceDB, WorkspaceDB
from .schemas import UserCreate, UserCreateWithPassword
+from ..users.exceptions import UserAlreadyExistsError, UserNotFoundError
PASSWORD_LENGTH = 12
-class UserNotFoundError(Exception):
- """Exception raised when a user is not found in the database."""
-
-
-class UserAlreadyExistsError(Exception):
- """Exception raised when a user already exists in the database."""
-
-
class UserDB(Base):
"""
SQL Alchemy data model for users
diff --git a/backend/app/users/routers.py b/backend/app/users/routers.py
index 80cd5e3..443dde3 100644
--- a/backend/app/users/routers.py
+++ b/backend/app/users/routers.py
@@ -12,11 +12,11 @@
from ..database import get_async_session, get_redis
from ..email import EmailService
from ..users.models import (
- UserAlreadyExistsError,
UserDB,
save_user_to_db,
update_user_api_key,
)
+from ..users.exceptions import UserAlreadyExistsError
from ..utils import generate_key, setup_logger, update_api_limits
from .schemas import KeyResponse, UserCreate, UserCreateWithPassword, UserRetrieve
@@ -45,8 +45,16 @@ async def create_user(
"""
try:
# Import workspace functionality to avoid circular imports
- from ..workspaces.models import UserRoles, create_user_workspace_role
- from ..workspaces.utils import create_workspace
+ from ..workspaces.models import (
+ UserRoles,
+ create_user_workspace_role,
+ get_pending_invitations_by_email,
+ delete_pending_invitation
+ )
+ from ..workspaces.utils import (
+ create_workspace,
+ get_workspace_by_workspace_id
+ )
# Create the user
new_api_key = generate_key()
@@ -69,6 +77,8 @@ async def create_user(
user=UserCreate(
role=UserRoles.ADMIN,
username=user_new.username,
+ first_name=user_new.first_name,
+ last_name=user_new.last_name,
workspace_name=default_workspace_name,
),
is_default=True,
@@ -84,6 +94,29 @@ async def create_user(
workspace_db=workspace_db,
)
+ # Check for pending invitations
+ pending_invitations = await get_pending_invitations_by_email(
+ asession=asession, email=user_new.username
+ )
+
+ # Process pending invitations
+ for invitation in pending_invitations:
+ invite_workspace = await get_workspace_by_workspace_id(
+ asession=asession, workspace_id=invitation.workspace_id
+ )
+
+ # Add user to the invited workspace
+ await create_user_workspace_role(
+ asession=asession,
+ is_default_workspace=False,
+ user_db=user_new,
+ user_role=invitation.role,
+ workspace_db=invite_workspace,
+ )
+
+ # Delete the invitation
+ await delete_pending_invitation(asession=asession, invitation=invitation)
+
# Send verification email
token = await generate_verification_token(
user_new.user_id, user_new.username, redis
diff --git a/backend/app/workspaces/models.py b/backend/app/workspaces/models.py
index 9547917..34dec0a 100644
--- a/backend/app/workspaces/models.py
+++ b/backend/app/workspaces/models.py
@@ -8,6 +8,7 @@
ForeignKey,
Integer,
String,
+ and_,
case,
exists,
select,
@@ -16,6 +17,8 @@
)
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Mapped, mapped_column, relationship
+from sqlalchemy.exc import NoResultFound
+from ..users.exceptions import UserNotFoundError
from ..models import Base, ExperimentBaseDB
from ..users.schemas import UserCreate
@@ -75,6 +78,10 @@ class WorkspaceDB(Base):
"ExperimentBaseDB", back_populates="workspace", cascade="all, delete-orphan"
)
+ pending_invitations: Mapped[list["PendingInvitationDB"]] = relationship(
+ "PendingInvitationDB", back_populates="workspace", cascade="all, delete-orphan"
+ )
+
def __repr__(self) -> str:
"""Define the string representation for the `WorkspaceDB` class."""
return f""
@@ -117,6 +124,118 @@ def __repr__(self) -> str:
return f"."
+class PendingInvitationDB(Base):
+ """ORM for managing pending workspace invitations."""
+
+ __tablename__ = "pending_invitations"
+
+ invitation_id: Mapped[int] = mapped_column(Integer, primary_key=True)
+ email: Mapped[str] = mapped_column(String, nullable=False)
+ workspace_id: Mapped[int] = mapped_column(
+ Integer, ForeignKey("workspace.workspace_id", ondelete="CASCADE"), nullable=False
+ )
+ role: Mapped[UserRoles] = mapped_column(
+ Enum(UserRoles, native_enum=False), nullable=False
+ )
+ inviter_id: Mapped[int] = mapped_column(
+ Integer, ForeignKey("users.user_id"), nullable=False
+ )
+ created_datetime_utc: Mapped[datetime] = mapped_column(
+ DateTime(timezone=True), nullable=False
+ )
+
+ workspace: Mapped["WorkspaceDB"] = relationship(
+ "WorkspaceDB", back_populates="pending_invitations"
+ )
+
+ def __repr__(self) -> str:
+ return f""
+
+
+async def get_users_in_workspace(
+ *, asession: AsyncSession, workspace_db: WorkspaceDB
+) -> Sequence[UserWorkspaceDB]:
+ """Get all users in a workspace with their roles."""
+ stmt = (
+ select(UserWorkspaceDB)
+ .where(UserWorkspaceDB.workspace_id == workspace_db.workspace_id)
+ )
+ result = await asession.execute(stmt)
+ return result.unique().scalars().all()
+
+async def get_user_by_user_id(
+ user_id: int, asession: AsyncSession
+) -> "UserDB":
+ """Get a user by user ID."""
+ stmt = select(UserDB).where(UserDB.user_id == user_id)
+ result = await asession.execute(stmt)
+ try:
+ return result.scalar_one()
+ except NoResultFound as e:
+ raise UserNotFoundError(f"User with ID {user_id} not found") from e
+
+async def remove_user_from_workspace(
+ *, asession: AsyncSession, user_db: "UserDB", workspace_db: WorkspaceDB
+) -> None:
+ """Remove a user from a workspace."""
+ # Check if user exists in workspace
+ stmt = select(UserWorkspaceDB).where(
+ and_(
+ UserWorkspaceDB.user_id == user_db.user_id,
+ UserWorkspaceDB.workspace_id == workspace_db.workspace_id,
+ )
+ )
+ result = await asession.execute(stmt)
+ user_workspace = result.scalar_one_or_none()
+
+ if not user_workspace:
+ raise UserNotFoundInWorkspaceError(
+ f"User '{user_db.username}' not found in workspace '{workspace_db.workspace_name}'."
+ )
+
+ # Delete the relationship
+ await asession.delete(user_workspace)
+ await asession.commit()
+
+async def create_pending_invitation(
+ *,
+ asession: AsyncSession,
+ email: str,
+ workspace_db: WorkspaceDB,
+ role: UserRoles,
+ inviter_id: int,
+) -> PendingInvitationDB:
+ """Create a pending invitation."""
+ invitation = PendingInvitationDB(
+ email=email,
+ workspace_id=workspace_db.workspace_id,
+ role=role,
+ inviter_id=inviter_id,
+ created_datetime_utc=datetime.now(timezone.utc),
+ )
+
+ asession.add(invitation)
+ await asession.commit()
+ await asession.refresh(invitation)
+
+ return invitation
+
+async def get_pending_invitations_by_email(
+ *, asession: AsyncSession, email: str
+) -> Sequence[PendingInvitationDB]:
+ """Get all pending invitations for an email."""
+ stmt = select(PendingInvitationDB).where(PendingInvitationDB.email == email)
+ result = await asession.execute(stmt)
+ return result.scalars().all()
+
+async def delete_pending_invitation(
+ *, asession: AsyncSession, invitation: PendingInvitationDB
+) -> None:
+ """Delete a pending invitation."""
+ await asession.delete(invitation)
+ await asession.commit()
+
+
async def check_if_user_has_default_workspace(
*, asession: AsyncSession, user_db: "UserDB"
) -> bool | None:
diff --git a/backend/app/workspaces/routers.py b/backend/app/workspaces/routers.py
index 2cb8800..2cbbf63 100644
--- a/backend/app/workspaces/routers.py
+++ b/backend/app/workspaces/routers.py
@@ -6,9 +6,12 @@
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
+from ..users.exceptions import UserNotFoundError
+
from ..auth.dependencies import (
create_access_token,
get_current_user,
+ get_verified_user,
)
from ..auth.schemas import AuthenticationDetails
from ..config import DEFAULT_API_QUOTA, DEFAULT_EXPERIMENTS_QUOTA
@@ -16,17 +19,21 @@
from ..email import EmailService
from ..users.models import (
UserDB,
- UserNotFoundError,
get_user_by_username,
)
-from ..users.schemas import UserCreate
+from ..users.schemas import MessageResponse, UserCreate
from ..utils import generate_key, setup_logger
from .models import (
+ UserNotFoundInWorkspaceError,
add_existing_user_to_workspace,
check_if_user_has_default_workspace,
+ create_pending_invitation,
+ get_user_by_user_id,
get_user_default_workspace,
get_user_role_in_workspace,
get_user_workspaces,
+ get_users_in_workspace,
+ remove_user_from_workspace,
update_user_default_workspace,
)
from .schemas import (
@@ -38,6 +45,7 @@
WorkspaceRetrieve,
WorkspaceSwitch,
WorkspaceUpdate,
+ WorkspaceUserResponse,
)
from .utils import (
WorkspaceNotFoundError,
@@ -93,6 +101,8 @@ async def create_workspace_endpoint(
user=UserCreate(
role=UserRoles.ADMIN,
username=calling_user_db.username,
+ first_name=calling_user_db.first_name,
+ last_name=calling_user_db.last_name,
workspace_name=workspace.workspace_name,
),
api_key=api_key,
@@ -106,6 +116,8 @@ async def create_workspace_endpoint(
is_default_workspace=False, # Don't make it default automatically
role=UserRoles.ADMIN,
username=calling_user_db.username,
+ first_name=calling_user_db.first_name,
+ last_name=calling_user_db.last_name,
workspace_name=workspace_db.workspace_name,
),
workspace_db=workspace_db,
@@ -428,6 +440,8 @@ async def invite_user_to_workspace(
user=UserCreate(
role=invite.role,
username=invite.email,
+ first_name=invited_user.first_name,
+ last_name=invited_user.last_name,
workspace_name=invite.workspace_name,
),
workspace_db=workspace_db,
@@ -451,7 +465,16 @@ async def invite_user_to_workspace(
)
except UserNotFoundError:
- # User doesn't exist, send invitation to create account
+ # User doesn't exist, create pending invitation
+ await create_pending_invitation(
+ asession=asession,
+ email=invite.email,
+ workspace_db=workspace_db,
+ role=invite.role,
+ inviter_id=calling_user_db.user_id,
+ )
+
+ # Send invitation email
background_tasks.add_task(
email_service.send_workspace_invitation_email,
invite.email,
@@ -479,3 +502,113 @@ async def invite_user_to_workspace(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error inviting user to workspace.",
) from e
+
+
+@router.get("/{workspace_id}/users", response_model=list[WorkspaceUserResponse])
+async def get_workspace_users(
+ workspace_id: int,
+ calling_user_db: Annotated[UserDB, Depends(get_verified_user)],
+ asession: AsyncSession = Depends(get_async_session),
+) -> list[WorkspaceUserResponse]:
+ """Get all users in a workspace."""
+ try:
+ workspace_db = await get_workspace_by_workspace_id(
+ asession=asession, workspace_id=workspace_id
+ )
+
+ user_role = await get_user_role_in_workspace(
+ asession=asession, user_db=calling_user_db, workspace_db=workspace_db
+ )
+
+ if user_role is None:
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail=f"User does not have access to workspace with ID {workspace_id}.",
+ )
+
+ user_workspaces = await get_users_in_workspace(
+ asession=asession, workspace_db=workspace_db
+ )
+
+ result = []
+ for uw in user_workspaces:
+ user = await get_user_by_user_id(uw.user_id, asession)
+ result.append(
+ WorkspaceUserResponse(
+ user_id=user.user_id,
+ username=user.username,
+ first_name=user.first_name,
+ last_name=user.last_name,
+ role=uw.user_role,
+ is_default_workspace=uw.default_workspace,
+ created_datetime_utc=uw.created_datetime_utc,
+ )
+ )
+
+ return result
+ except WorkspaceNotFoundError as e:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail=f"Workspace with ID {workspace_id} not found.",
+ ) from e
+
+@router.delete("/{workspace_id}/users/{username}", response_model=MessageResponse)
+async def remove_user_from_workspace_endpoint(
+ workspace_id: int,
+ username: str,
+ calling_user_db: Annotated[UserDB, Depends(get_verified_user)],
+ asession: AsyncSession = Depends(get_async_session),
+) -> MessageResponse:
+ """Remove a user from a workspace."""
+ try:
+ workspace_db = await get_workspace_by_workspace_id(
+ asession=asession, workspace_id=workspace_id
+ )
+
+ caller_role = await get_user_role_in_workspace(
+ asession=asession, user_db=calling_user_db, workspace_db=workspace_db
+ )
+
+ if caller_role != UserRoles.ADMIN:
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail="Only workspace administrators can remove users.",
+ )
+
+ if workspace_db.is_default:
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail="Cannot remove users from default workspaces.",
+ )
+
+ if username == calling_user_db.username:
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail="You cannot remove yourself from a workspace.",
+ )
+
+ try:
+ user_to_remove = await get_user_by_username(username=username, asession=asession)
+ except UserNotFoundError:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail=f"User with username '{username}' not found.",
+ )
+
+ try:
+ await remove_user_from_workspace(
+ asession=asession, user_db=user_to_remove, workspace_db=workspace_db
+ )
+ return MessageResponse(
+ message=f"User '{username}' successfully removed from workspace '{workspace_db.workspace_name}'."
+ )
+ except UserNotFoundInWorkspaceError:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail=f"User '{username}' is not a member of workspace '{workspace_db.workspace_name}'.",
+ )
+ except WorkspaceNotFoundError:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail=f"Workspace with ID {workspace_id} not found.",
+ )
diff --git a/backend/app/workspaces/schemas.py b/backend/app/workspaces/schemas.py
index 512a78b..2e08fcd 100644
--- a/backend/app/workspaces/schemas.py
+++ b/backend/app/workspaces/schemas.py
@@ -118,3 +118,17 @@ class WorkspaceInviteResponse(BaseModel):
user_exists: bool
model_config = ConfigDict(from_attributes=True)
+
+
+class WorkspaceUserResponse(BaseModel):
+ """Pydantic model for workspace user information."""
+
+ user_id: int
+ username: str
+ first_name: str
+ last_name: str
+ role: UserRoles
+ is_default_workspace: bool
+ created_datetime_utc: datetime
+
+ model_config = ConfigDict(from_attributes=True)
diff --git a/backend/migrations/versions/949c9fc0461d_workspace_relationship.py b/backend/migrations/versions/949c9fc0461d_workspace_relationship.py
deleted file mode 100644
index d04edf2..0000000
--- a/backend/migrations/versions/949c9fc0461d_workspace_relationship.py
+++ /dev/null
@@ -1,36 +0,0 @@
-"""Workspace relationship
-
-Revision ID: 949c9fc0461d
-Revises: 977e7e73ce06
-Create Date: 2025-04-21 21:20:56.282928
-
-"""
-
-from typing import Sequence, Union
-
-import sqlalchemy as sa
-from alembic import op
-
-# revision identifiers, used by Alembic.
-revision: str = "949c9fc0461d"
-down_revision: Union[str, None] = "977e7e73ce06"
-branch_labels: Union[str, Sequence[str], None] = None
-depends_on: Union[str, Sequence[str], None] = None
-
-
-def upgrade() -> None:
- # ### commands auto generated by Alembic - please adjust! ###
- op.add_column(
- "experiments_base", sa.Column("workspace_id", sa.Integer(), nullable=False)
- )
- op.create_foreign_key(
- None, "experiments_base", "workspace", ["workspace_id"], ["workspace_id"]
- )
- # ### end Alembic commands ###
-
-
-def downgrade() -> None:
- # ### commands auto generated by Alembic - please adjust! ###
- op.drop_constraint(None, "experiments_base", type_="foreignkey")
- op.drop_column("experiments_base", "workspace_id")
- # ### end Alembic commands ###
diff --git a/backend/migrations/versions/977e7e73ce06_workspace_model.py b/backend/migrations/versions/977e7e73ce06_workspace_model.py
deleted file mode 100644
index bb4ff38..0000000
--- a/backend/migrations/versions/977e7e73ce06_workspace_model.py
+++ /dev/null
@@ -1,71 +0,0 @@
-"""Workspace model
-
-Revision ID: 977e7e73ce06
-Revises: ba1bf29910f5
-Create Date: 2025-04-20 20:17:32.839934
-
-"""
-
-from typing import Sequence, Union
-
-import sqlalchemy as sa
-from alembic import op
-
-# revision identifiers, used by Alembic.
-revision: str = "977e7e73ce06"
-down_revision: Union[str, None] = "ba1bf29910f5"
-branch_labels: Union[str, Sequence[str], None] = None
-depends_on: Union[str, Sequence[str], None] = None
-
-
-def upgrade() -> None:
- # ### commands auto generated by Alembic - please adjust! ###
- op.create_table(
- "workspace",
- sa.Column("api_daily_quota", sa.Integer(), nullable=True),
- sa.Column("api_key_first_characters", sa.String(length=5), nullable=True),
- sa.Column(
- "api_key_updated_datetime_utc", sa.DateTime(timezone=True), nullable=True
- ),
- sa.Column("content_quota", sa.Integer(), nullable=True),
- sa.Column("created_datetime_utc", sa.DateTime(timezone=True), nullable=False),
- sa.Column("hashed_api_key", sa.String(length=96), nullable=True),
- sa.Column("updated_datetime_utc", sa.DateTime(timezone=True), nullable=False),
- sa.Column("workspace_id", sa.Integer(), nullable=False),
- sa.Column("workspace_name", sa.String(), nullable=False),
- sa.Column("is_default", sa.Boolean(), nullable=False),
- sa.PrimaryKeyConstraint("workspace_id"),
- sa.UniqueConstraint("hashed_api_key"),
- sa.UniqueConstraint("workspace_name"),
- )
- op.create_table(
- "user_workspace",
- sa.Column("created_datetime_utc", sa.DateTime(timezone=True), nullable=False),
- sa.Column(
- "default_workspace",
- sa.Boolean(),
- server_default=sa.text("false"),
- nullable=False,
- ),
- sa.Column("updated_datetime_utc", sa.DateTime(timezone=True), nullable=False),
- sa.Column("user_id", sa.Integer(), nullable=False),
- sa.Column(
- "user_role",
- sa.Enum("ADMIN", "READ_ONLY", name="userroles", native_enum=False),
- nullable=False,
- ),
- sa.Column("workspace_id", sa.Integer(), nullable=False),
- sa.ForeignKeyConstraint(["user_id"], ["users.user_id"], ondelete="CASCADE"),
- sa.ForeignKeyConstraint(
- ["workspace_id"], ["workspace.workspace_id"], ondelete="CASCADE"
- ),
- sa.PrimaryKeyConstraint("user_id", "workspace_id"),
- )
- # ### end Alembic commands ###
-
-
-def downgrade() -> None:
- # ### commands auto generated by Alembic - please adjust! ###
- op.drop_table("user_workspace")
- op.drop_table("workspace")
- # ### end Alembic commands ###
diff --git a/backend/migrations/versions/d9f7a309944e_workspace_model.py b/backend/migrations/versions/d9f7a309944e_workspace_model.py
new file mode 100644
index 0000000..1482764
--- /dev/null
+++ b/backend/migrations/versions/d9f7a309944e_workspace_model.py
@@ -0,0 +1,72 @@
+"""Workspace model
+
+Revision ID: d9f7a309944e
+Revises: 5c15463fda65
+Create Date: 2025-04-30 23:23:21.122138
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = 'd9f7a309944e'
+down_revision: Union[str, None] = '5c15463fda65'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.create_table('workspace',
+ sa.Column('api_daily_quota', sa.Integer(), nullable=True),
+ sa.Column('api_key_first_characters', sa.String(length=5), nullable=True),
+ sa.Column('api_key_updated_datetime_utc', sa.DateTime(timezone=True), nullable=True),
+ sa.Column('content_quota', sa.Integer(), nullable=True),
+ sa.Column('created_datetime_utc', sa.DateTime(timezone=True), nullable=False),
+ sa.Column('hashed_api_key', sa.String(length=96), nullable=True),
+ sa.Column('updated_datetime_utc', sa.DateTime(timezone=True), nullable=False),
+ sa.Column('workspace_id', sa.Integer(), nullable=False),
+ sa.Column('workspace_name', sa.String(), nullable=False),
+ sa.Column('is_default', sa.Boolean(), nullable=False),
+ sa.PrimaryKeyConstraint('workspace_id'),
+ sa.UniqueConstraint('hashed_api_key'),
+ sa.UniqueConstraint('workspace_name')
+ )
+ op.create_table('pending_invitations',
+ sa.Column('invitation_id', sa.Integer(), nullable=False),
+ sa.Column('email', sa.String(), nullable=False),
+ sa.Column('workspace_id', sa.Integer(), nullable=False),
+ sa.Column('role', sa.Enum('ADMIN', 'READ_ONLY', name='userroles', native_enum=False), nullable=False),
+ sa.Column('inviter_id', sa.Integer(), nullable=False),
+ sa.Column('created_datetime_utc', sa.DateTime(timezone=True), nullable=False),
+ sa.ForeignKeyConstraint(['inviter_id'], ['users.user_id'], ),
+ sa.ForeignKeyConstraint(['workspace_id'], ['workspace.workspace_id'], ondelete='CASCADE'),
+ sa.PrimaryKeyConstraint('invitation_id')
+ )
+ op.create_table('user_workspace',
+ sa.Column('created_datetime_utc', sa.DateTime(timezone=True), nullable=False),
+ sa.Column('default_workspace', sa.Boolean(), server_default=sa.text('false'), nullable=False),
+ sa.Column('updated_datetime_utc', sa.DateTime(timezone=True), nullable=False),
+ sa.Column('user_id', sa.Integer(), nullable=False),
+ sa.Column('user_role', sa.Enum('ADMIN', 'READ_ONLY', name='userroles', native_enum=False), nullable=False),
+ sa.Column('workspace_id', sa.Integer(), nullable=False),
+ sa.ForeignKeyConstraint(['user_id'], ['users.user_id'], ondelete='CASCADE'),
+ sa.ForeignKeyConstraint(['workspace_id'], ['workspace.workspace_id'], ondelete='CASCADE'),
+ sa.PrimaryKeyConstraint('user_id', 'workspace_id')
+ )
+ op.add_column('experiments_base', sa.Column('workspace_id', sa.Integer(), nullable=False))
+ op.create_foreign_key(None, 'experiments_base', 'workspace', ['workspace_id'], ['workspace_id'])
+ # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_constraint(None, 'experiments_base', type_='foreignkey')
+ op.drop_column('experiments_base', 'workspace_id')
+ op.drop_table('user_workspace')
+ op.drop_table('pending_invitations')
+ op.drop_table('workspace')
+ # ### end Alembic commands ###
From 64768d6af15ca7bb6aae7444f22d9ce1e76a7dcd Mon Sep 17 00:00:00 2001
From: Jay Prakash <0freerunning@gmail.com>
Date: Thu, 1 May 2025 13:08:56 +0300
Subject: [PATCH 07/74] New frontend
---
backend/app/users/routers.py | 31 -
backend/app/workspaces/models.py | 1 +
.../src/app/(protected)/integration/api.ts | 25 -
.../integration/components/ApiKeyDisplay.tsx | 135 ----
.../src/app/(protected)/integration/page.tsx | 26 -
.../app/(protected)/workspace/create/page.tsx | 131 ----
.../app/(protected)/workspace/invite/page.tsx | 193 ------
.../src/app/(protected)/workspace/page.tsx | 116 ----
.../src/app/(protected)/workspace/types.ts | 50 --
.../workspaces/[workspaceId]/page.tsx | 589 ++++++++++++++++++
.../[workspaceId]/users/invite/page.tsx | 244 ++++++++
.../src/app/(protected)/workspaces/page.tsx | 243 ++++++++
.../src/app/(protected)/workspaces/types.ts | 22 +
frontend/src/components/WorkspaceSelector.tsx | 142 -----
frontend/src/components/app-sidebar.tsx | 121 +---
.../components/create-workspace-dialog.tsx | 152 +++++
frontend/src/components/ui/alert-dialog.tsx | 143 +++++
frontend/src/components/ui/tabs.tsx | 55 ++
.../src/components/workspace-switcher.tsx | 65 +-
frontend/src/utils/api.ts | 186 ++++--
frontend/src/utils/auth.tsx | 183 ++++--
21 files changed, 1778 insertions(+), 1075 deletions(-)
delete mode 100644 frontend/src/app/(protected)/integration/api.ts
delete mode 100644 frontend/src/app/(protected)/integration/components/ApiKeyDisplay.tsx
delete mode 100644 frontend/src/app/(protected)/integration/page.tsx
delete mode 100644 frontend/src/app/(protected)/workspace/create/page.tsx
delete mode 100644 frontend/src/app/(protected)/workspace/invite/page.tsx
delete mode 100644 frontend/src/app/(protected)/workspace/page.tsx
delete mode 100644 frontend/src/app/(protected)/workspace/types.ts
create mode 100644 frontend/src/app/(protected)/workspaces/[workspaceId]/page.tsx
create mode 100644 frontend/src/app/(protected)/workspaces/[workspaceId]/users/invite/page.tsx
create mode 100644 frontend/src/app/(protected)/workspaces/page.tsx
create mode 100644 frontend/src/app/(protected)/workspaces/types.ts
delete mode 100644 frontend/src/components/WorkspaceSelector.tsx
create mode 100644 frontend/src/components/create-workspace-dialog.tsx
create mode 100644 frontend/src/components/ui/alert-dialog.tsx
create mode 100644 frontend/src/components/ui/tabs.tsx
diff --git a/backend/app/users/routers.py b/backend/app/users/routers.py
index 443dde3..5f05144 100644
--- a/backend/app/users/routers.py
+++ b/backend/app/users/routers.py
@@ -153,34 +153,3 @@ async def get_user(
return UserRetrieve.model_validate(user_db)
-
-@router.put("/rotate-key", response_model=KeyResponse)
-async def get_new_api_key(
- user_db: Annotated[UserDB, Depends(get_verified_user)],
- asession: AsyncSession = Depends(get_async_session),
-) -> KeyResponse | None:
- """
- Generate a new API key for the requester's account. Takes a user object,
- generates a new key, replaces the old one in the database, and returns
- a user object with the new key.
- """
-
- new_api_key = generate_key()
-
- try:
- # this is neccesarry to attach the user_db to the session
- asession.add(user_db)
- await update_user_api_key(
- user_db=user_db,
- new_api_key=new_api_key,
- asession=asession,
- )
- return KeyResponse(
- username=user_db.username,
- new_api_key=new_api_key,
- )
- except SQLAlchemyError as e:
- logger.error(f"Error updating user api key: {e}")
- raise HTTPException(
- status_code=500, detail="Error updating user api key"
- ) from e
diff --git a/backend/app/workspaces/models.py b/backend/app/workspaces/models.py
index 34dec0a..23d601c 100644
--- a/backend/app/workspaces/models.py
+++ b/backend/app/workspaces/models.py
@@ -167,6 +167,7 @@ async def get_user_by_user_id(
user_id: int, asession: AsyncSession
) -> "UserDB":
"""Get a user by user ID."""
+ from ..users.models import UserDB
stmt = select(UserDB).where(UserDB.user_id == user_id)
result = await asession.execute(stmt)
try:
diff --git a/frontend/src/app/(protected)/integration/api.ts b/frontend/src/app/(protected)/integration/api.ts
deleted file mode 100644
index 47fe784..0000000
--- a/frontend/src/app/(protected)/integration/api.ts
+++ /dev/null
@@ -1,25 +0,0 @@
-import api from "@/utils/api";
-
-const getUser = async (token: string | null) => {
- const response = await api.get("/user/", {
- headers: {
- Authorization: `Bearer ${token}`,
- },
- });
- return response.data;
-};
-
-const rotateAPIKey = async (token: string | null) => {
- const response = await api.put(
- "/user/rotate-key",
- {},
- {
- headers: {
- Authorization: `Bearer ${token}`,
- },
- }
- );
- return response.data;
-};
-
-export { getUser, rotateAPIKey };
diff --git a/frontend/src/app/(protected)/integration/components/ApiKeyDisplay.tsx b/frontend/src/app/(protected)/integration/components/ApiKeyDisplay.tsx
deleted file mode 100644
index 5bc03da..0000000
--- a/frontend/src/app/(protected)/integration/components/ApiKeyDisplay.tsx
+++ /dev/null
@@ -1,135 +0,0 @@
-"use client";
-
-import { useState } from "react";
-import { Button } from "@/components/ui/button";
-import {
- Dialog,
- DialogContent,
- DialogDescription,
- DialogTitle,
-} from "@/components/ui/dialog";
-import { KeyRound, Copy } from "lucide-react";
-import { useToast } from "@/hooks/use-toast";
-import { useAuth } from "@/utils/auth";
-import { useEffect } from "react";
-import Hourglass from "@/components/Hourglass";
-import { getUser, rotateAPIKey } from "../api";
-
-export function ApiKeyDisplay() {
- const { token } = useAuth();
-
- const [apiKey, setApiKey] = useState("");
- const [isModalOpen, setIsModalOpen] = useState(false);
- const [newKey, setNewKey] = useState("");
- const [isRefreshing, setIsRefreshing] = useState(false);
- const [isLoading, setIsLoading] = useState(false);
-
- const { toast } = useToast();
-
- useEffect(() => {
- setIsLoading(true);
- getUser(token)
- .then((data) => {
- setApiKey(data.api_key_first_characters);
- })
- .catch((error: Error) => {
- console.log(error);
- })
- .finally(() => {
- setIsLoading(false);
- });
- }, [token]);
-
- const handleGenerateKey = async () => {
- setIsRefreshing(true);
- rotateAPIKey(token)
- .then((data) => {
- setNewKey(data.new_api_key);
- setIsModalOpen(true);
- })
- .catch((error: Error) => {
- console.log(error);
- toast({
- title: "Error",
- description: "Failed to generate new API key",
- variant: "destructive",
- });
- })
- .finally(() => {
- setIsRefreshing(false);
- });
- };
-
- const handleCopyKey = async () => {
- try {
- await navigator.clipboard.writeText(newKey);
- toast({
- title: "Copied!",
- description: "API key copied to clipboard",
- });
- } catch (error) {
- toast({
- title: "Error",
- description: "Failed to copy API key",
- variant: "destructive",
- });
- }
- };
-
- const handleConfirm = () => {
- setApiKey(newKey);
- setIsModalOpen(false);
- toast({
- title: "Success",
- description: "API key has been updated",
- });
- };
-
- return isLoading ? (
-
- ) : (
-
-
-
-
-
- {apiKey.slice(0, 5)}
- {"•".repeat(27)}
-
-
-
- {isRefreshing ? "Generating..." : "Recreate Key"}
-
-
-
-
-
- New API Key Generated
-
- Make sure to copy your new API key. You won't be able to see it
- again!
-
-
-
- {newKey}
-
-
-
-
- Copy
-
-
- I've saved my API key
-
-
-
-
-
-
- );
-}
diff --git a/frontend/src/app/(protected)/integration/page.tsx b/frontend/src/app/(protected)/integration/page.tsx
deleted file mode 100644
index 6922061..0000000
--- a/frontend/src/app/(protected)/integration/page.tsx
+++ /dev/null
@@ -1,26 +0,0 @@
-import {
- Card,
- CardContent,
- CardDescription,
- CardHeader,
- CardTitle,
-} from "@/components/ui/card";
-import { ApiKeyDisplay } from "./components/ApiKeyDisplay";
-
-export default function ApiKeyPage() {
- return (
-
-
-
- API Key Management
-
- View and manage your API keys securely
-
-
-
-
-
-
-
- );
-}
diff --git a/frontend/src/app/(protected)/workspace/create/page.tsx b/frontend/src/app/(protected)/workspace/create/page.tsx
deleted file mode 100644
index 99aff9e..0000000
--- a/frontend/src/app/(protected)/workspace/create/page.tsx
+++ /dev/null
@@ -1,131 +0,0 @@
-"use client";
-
-import { zodResolver } from "@hookform/resolvers/zod";
-import { useForm } from "react-hook-form";
-import { z } from "zod";
-import { Button } from "@/components/catalyst/button";
-import { Input } from "@/components/catalyst/input";
-import { useAuth } from "@/utils/auth";
-import { apiCalls } from "@/utils/api";
-import { useToast } from "@/hooks/use-toast";
-import { useRouter } from "next/navigation";
-import { useState } from "react";
-import {
- Fieldset,
- Field,
- FieldGroup,
- Label,
- Description,
-} from "@/components/catalyst/fieldset";
-import { Card, CardContent, CardHeader, CardTitle, CardDescription } from "@/components/ui/card";
-import { BuildingOfficeIcon } from "@heroicons/react/20/solid";
-
-const formSchema = z.object({
- workspace_name: z.string().min(3, {
- message: "Workspace name must be at least 3 characters",
- }),
-});
-
-type FormValues = z.infer;
-
-export default function CreateWorkspacePage() {
- const { token, switchWorkspace } = useAuth();
- const { toast } = useToast();
- const router = useRouter();
- const [isSubmitting, setIsSubmitting] = useState(false);
-
- const form = useForm({
- resolver: zodResolver(formSchema),
- defaultValues: {
- workspace_name: "",
- },
- });
-
- const onSubmit = async (data: FormValues) => {
- if (!token) {
- toast({
- title: "Error",
- description: "You must be logged in to create a workspace",
- variant: "destructive",
- });
- return;
- }
-
- setIsSubmitting(true);
- try {
- const response = await apiCalls.createWorkspace(token, data);
-
- await switchWorkspace(response.workspace_name);
-
- toast({
- title: "Success",
- description: `Workspace "${response.workspace_name}" created and activated!`,
- });
-
- router.push("/workspace");
- } catch (error: any) {
- toast({
- title: "Error",
- description: error.message || "Failed to create workspace",
- variant: "destructive",
- });
- } finally {
- setIsSubmitting(false);
- }
- };
-
- return (
-
-
-
-
-
-
- Create new workspace
-
- Create a new workspace to organize your experiments and team members
-
-
-
-
-
-
-
-
-
- );
-}
diff --git a/frontend/src/app/(protected)/workspace/invite/page.tsx b/frontend/src/app/(protected)/workspace/invite/page.tsx
deleted file mode 100644
index f610ffc..0000000
--- a/frontend/src/app/(protected)/workspace/invite/page.tsx
+++ /dev/null
@@ -1,193 +0,0 @@
-"use client";
-
-import { useState } from "react";
-import { useAuth } from "@/utils/auth";
-import { apiCalls } from "@/utils/api";
-import { useToast } from "@/hooks/use-toast";
-import { z } from "zod";
-import { zodResolver } from "@hookform/resolvers/zod";
-import { useForm } from "react-hook-form";
-import { Button } from "@/components/catalyst/button";
-import { Heading } from "@/components/catalyst/heading";
-import { Input } from "@/components/catalyst/input";
-import {
- Card,
- CardContent,
- CardDescription,
- CardFooter,
- CardHeader,
- CardTitle,
-} from "@/components/ui/card";
-import {
- Fieldset,
- Field,
- FieldGroup,
- Label,
- Description,
-} from "@/components/catalyst/fieldset";
-import { Radio, RadioField, RadioGroup } from "@/components/catalyst/radio";
-import { Badge } from "@/components/ui/badge";
-import { EnvelopeIcon, UserPlusIcon } from "@heroicons/react/20/solid";
-
-const inviteSchema = z.object({
- email: z.string().email({
- message: "Please enter a valid email address",
- }),
- role: z.enum(["ADMIN", "EDITOR", "VIEWER"], {
- required_error: "Please select a role",
- }),
-});
-
-type InviteFormValues = z.infer;
-
-export default function InviteUsersPage() {
- const { token, currentWorkspace } = useAuth();
- const { toast } = useToast();
- const [isSubmitting, setIsSubmitting] = useState(false);
- const [invitedUsers, setInvitedUsers] = useState<{ email: string; role: string; exists: boolean }[]>([]);
-
- const {
- register,
- handleSubmit,
- reset,
- formState: { errors },
- setValue,
- watch,
- } = useForm({
- resolver: zodResolver(inviteSchema),
- defaultValues: {
- email: "",
- role: "VIEWER",
- },
- });
-
- const roleValue = watch("role");
-
- const onSubmit = async (data: InviteFormValues) => {
- if (!token || !currentWorkspace) {
- toast({
- title: "Error",
- description: "You must be logged in and have a workspace selected",
- variant: "destructive",
- });
- return;
- }
-
- setIsSubmitting(true);
- try {
- const response = await apiCalls.inviteUserToWorkspace(token, {
- email: data.email,
- role: data.role,
- workspace_name: currentWorkspace.workspace_name,
- });
-
- // Add to invited users list
- setInvitedUsers([
- ...invitedUsers,
- {
- email: data.email,
- role: data.role,
- exists: response.user_exists,
- },
- ]);
-
- // Reset form
- reset();
-
- toast({
- title: "Success",
- description: `Invitation sent to ${data.email}`,
- });
- } catch (error: any) {
- toast({
- title: "Error",
- description: error.message || "Failed to send invitation",
- variant: "destructive",
- });
- } finally {
- setIsSubmitting(false);
- }
- };
-
- return (
-
-
- Invite Team Members
-
-
-
-
-
-
- Send Invitation
-
- Invite users to join your workspace: {currentWorkspace?.workspace_name}
-
-
-
-
-
-
-
-
-
- );
-}
diff --git a/frontend/src/app/(protected)/workspace/page.tsx b/frontend/src/app/(protected)/workspace/page.tsx
deleted file mode 100644
index 70b4d4c..0000000
--- a/frontend/src/app/(protected)/workspace/page.tsx
+++ /dev/null
@@ -1,116 +0,0 @@
-"use client";
-
-import { useEffect } from "react";
-import { useRouter } from "next/navigation";
-import { useAuth } from "@/utils/auth";
-import { Heading } from "@/components/catalyst/heading";
-import { Card, CardContent, CardHeader, CardTitle, CardDescription } from "@/components/ui/card";
-import { Plus, Users, Settings, Key } from "lucide-react";
-import { Button } from "@/components/catalyst/button";
-
-export default function WorkspacePage() {
- const { currentWorkspace } = useAuth();
- const router = useRouter();
-
- if (!currentWorkspace) {
- return (
-
-
-
-
- Something went wrong. Please try again later.
-
-
-
-
- );
- }
- console.log("Current Workspace:", currentWorkspace);
-
- return (
-
-
- {currentWorkspace.workspace_name}
-
-
-
-
-
- Workspace Information
-
-
-
-
-
Name:
- {currentWorkspace.workspace_name}
-
-
-
API Quota:
- {currentWorkspace.api_daily_quota.toLocaleString()} calls/day
-
-
-
Experiment Quota:
- {currentWorkspace.content_quota.toLocaleString()} experiments
-
-
-
Created:
-
- {new Date(currentWorkspace.created_datetime_utc).toLocaleDateString()}
-
-
-
-
-
-
-
-
- API Key
-
-
-
-
-
-
- {currentWorkspace.api_key_first_characters}
- {"•".repeat(27)}
-
-
-
router.push('/integration')}>
- Manage API Keys
-
-
-
- Use this API key to authenticate your API requests. Keep it secret and secure.
-
-
-
-
-
-
-
router.push('/workspace/invite')}>
-
-
- Invite Team Members
-
- Invite colleagues to join your workspace
-
-
-
-
-
-
-
router.push('/workspace/create')}>
-
-
- Create New Workspace
-
- Create a new workspace for different projects
-
-
-
-
-
-
-
- );
-}
diff --git a/frontend/src/app/(protected)/workspace/types.ts b/frontend/src/app/(protected)/workspace/types.ts
deleted file mode 100644
index 2d53c31..0000000
--- a/frontend/src/app/(protected)/workspace/types.ts
+++ /dev/null
@@ -1,50 +0,0 @@
-export enum UserRoles {
- ADMIN = "ADMIN",
- VIEWER = "VIEWER",
-}
-
-export interface Workspace {
- workspace_id: number;
- workspace_name: string;
- api_key_first_characters: string;
- api_key_updated_datetime_utc: string;
- api_daily_quota: number;
- content_quota: number;
- created_datetime_utc: string;
- updated_datetime_utc: string;
- is_default: boolean;
-}
-
-export interface WorkspaceCreate {
- workspace_name: string;
- api_daily_quota?: number;
- content_quota?: number;
-}
-
-export interface WorkspaceUpdate {
- workspace_name?: string;
- api_daily_quota?: number;
- content_quota?: number;
-}
-
-export interface WorkspaceKeyResponse {
- new_api_key: string;
- workspace_name: string;
-}
-
-export interface WorkspaceInvite {
- email: string;
- role: UserRoles;
- workspace_name: string;
-}
-
-export interface WorkspaceInviteResponse {
- message: string;
- email: string;
- workspace_name: string;
- user_exists: boolean;
-}
-
-export interface WorkspaceSwitch {
- workspace_name: string;
-}
diff --git a/frontend/src/app/(protected)/workspaces/[workspaceId]/page.tsx b/frontend/src/app/(protected)/workspaces/[workspaceId]/page.tsx
new file mode 100644
index 0000000..a3d2a96
--- /dev/null
+++ b/frontend/src/app/(protected)/workspaces/[workspaceId]/page.tsx
@@ -0,0 +1,589 @@
+// Path: frontend/src/app/(protected)/workspaces/[workspaceId]/page.tsx
+"use client";
+
+import { useState, useEffect } from "react";
+import { useParams, useRouter } from "next/navigation";
+import { useAuth } from "@/utils/auth";
+import { apiCalls } from "@/utils/api";
+import { useToast } from "@/hooks/use-toast";
+
+import { Building, ChevronLeftIcon, Users, Key, Copy } from "lucide-react";
+import { Button } from "@/components/ui/button";
+import {
+ Card,
+ CardContent,
+ CardDescription,
+ CardFooter,
+ CardHeader,
+ CardTitle,
+} from "@/components/ui/card";
+import {
+ Breadcrumb,
+ BreadcrumbItem,
+ BreadcrumbLink,
+ BreadcrumbList,
+ BreadcrumbPage,
+ BreadcrumbSeparator,
+} from "@/components/ui/breadcrumb";
+import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
+import Hourglass from "@/components/Hourglass";
+import { Separator } from "@/components/ui/separator";
+import {
+ AlertDialog,
+ AlertDialogAction,
+ AlertDialogCancel,
+ AlertDialogContent,
+ AlertDialogDescription,
+ AlertDialogFooter,
+ AlertDialogHeader,
+ AlertDialogTitle,
+ AlertDialogTrigger,
+} from "@/components/ui/alert-dialog";
+import {
+ Dialog,
+ DialogContent,
+ DialogDescription,
+ DialogFooter,
+ DialogHeader,
+ DialogTitle,
+} from "@/components/ui/dialog";
+
+export default function WorkspaceDetailPage() {
+ const params = useParams();
+ const router = useRouter();
+ const { token, currentWorkspace, fetchWorkspaces, switchWorkspace } =
+ useAuth();
+ const { toast } = useToast();
+
+ const [workspace, setWorkspace] = useState(null);
+ const [workspaceUsers, setWorkspaceUsers] = useState([]);
+ const [isLoading, setIsLoading] = useState(true);
+ const [isRotatingKey, setIsRotatingKey] = useState(false);
+ const [newApiKey, setNewApiKey] = useState(null);
+ const [isApiKeyDialogOpen, setIsApiKeyDialogOpen] = useState(false);
+
+ const workspaceId = Number(params.workspaceId);
+
+ useEffect(() => {
+ const loadWorkspaceData = async () => {
+ if (!token) return;
+
+ setIsLoading(true);
+ try {
+ // Fetch workspace details
+ const workspaceData = await apiCalls.getWorkspaceById(
+ token,
+ workspaceId
+ );
+ setWorkspace(workspaceData);
+
+ // Fetch workspace users
+ const usersData = await apiCalls.getWorkspaceUsers(token, workspaceId);
+ setWorkspaceUsers(usersData);
+ } catch (error) {
+ console.error("Error loading workspace data:", error);
+ toast({
+ title: "Error",
+ description: "Failed to load workspace details",
+ variant: "destructive",
+ });
+ } finally {
+ setIsLoading(false);
+ }
+ };
+
+ loadWorkspaceData();
+ }, [token, workspaceId, toast]);
+
+ const handleRotateApiKey = async () => {
+ if (!token || !workspace) return;
+
+ setIsRotatingKey(true);
+ try {
+ // Need to switch to this workspace first if it's not the current one
+ if (currentWorkspace?.workspace_id !== workspaceId) {
+ await switchWorkspace(workspace.workspace_name);
+ }
+
+ const result = await apiCalls.rotateWorkspaceApiKey(token);
+ setNewApiKey(result.new_api_key);
+ setIsApiKeyDialogOpen(true);
+
+ // Refresh workspaces data
+ await fetchWorkspaces();
+
+ // Refresh workspace details
+ const updatedWorkspace = await apiCalls.getWorkspaceById(
+ token,
+ workspaceId
+ );
+ setWorkspace(updatedWorkspace);
+
+ toast({
+ title: "Success",
+ description: "API key rotated successfully",
+ variant: "success",
+ });
+ } catch (error: any) {
+ toast({
+ title: "Error",
+ description: error.message || "Failed to rotate API key",
+ variant: "destructive",
+ });
+ } finally {
+ setIsRotatingKey(false);
+ }
+ };
+
+ const handleCopyApiKey = () => {
+ if (!newApiKey) return;
+
+ navigator.clipboard.writeText(newApiKey);
+ toast({
+ title: "Copied",
+ description: "API key copied to clipboard",
+ });
+ };
+
+ const handleRemoveUser = async (username: string) => {
+ if (!token) return;
+
+ try {
+ await apiCalls.removeUserFromWorkspace(token, workspaceId, username);
+
+ // Refresh user list
+ const updatedUsers = await apiCalls.getWorkspaceUsers(token, workspaceId);
+ setWorkspaceUsers(updatedUsers);
+
+ toast({
+ title: "Success",
+ description: `${username} has been removed from the workspace`,
+ variant: "success",
+ });
+ } catch (error: any) {
+ toast({
+ title: "Error",
+ description: error.message || "Failed to remove user",
+ variant: "destructive",
+ });
+ }
+ };
+
+ if (isLoading) {
+ return (
+
+
+
+
+ Loading workspace details...
+
+
+
+ );
+ }
+
+ if (!workspace) {
+ return (
+
+
+
Workspace Not Found
+
+ The requested workspace could not be found
+
+
router.push("/workspaces")}>
+
+ Back to Workspaces
+
+
+
+ );
+ }
+
+ return (
+ <>
+
+
+
+
+
+ Home
+
+
+
+ Workspaces
+
+
+
+ {workspace.workspace_name}
+
+
+
+
+
+
+
+
+
+
+
+
+ {workspace.workspace_name}
+
+
+ Workspace ID: {workspace.workspace_id}
+
+
+
+
router.push("/workspaces")}>
+
+ Back to Workspaces
+
+
+
+
+
+ Overview
+ Users
+ API
+
+
+
+
+
+ Workspace Overview
+
+ Summary information about this workspace
+
+
+
+
+
+
+ Workspace Details
+
+
+
Name:
+
{workspace.workspace_name}
+
+
ID:
+
{workspace.workspace_id}
+
+
Created:
+
+ {new Date(
+ workspace.created_datetime_utc
+ ).toLocaleDateString()}
+
+
+
Last Updated:
+
+ {new Date(
+ workspace.updated_datetime_utc
+ ).toLocaleDateString()}
+
+
+
+ API Daily Quota:
+
+
{workspace.api_daily_quota} calls/day
+
+
+ Content Quota:
+
+
{workspace.content_quota} experiments
+
+
+
+
+
+ API Configuration
+
+
+
+ API Key Prefix:
+
+
+ {workspace.api_key_first_characters}•••••
+
+
+
+ Key Last Rotated:
+
+
+ {new Date(
+ workspace.api_key_updated_datetime_utc
+ ).toLocaleDateString()}
+
+
+
+
+
+ {currentWorkspace?.workspace_id !== workspace.workspace_id && (
+
+ {
+ try {
+ setIsLoading(true);
+ await switchWorkspace(workspace.workspace_name);
+ toast({
+ title: "Success",
+ description: `Switched to ${workspace.workspace_name} workspace`,
+ });
+ } catch (error) {
+ console.error("Error switching workspace:", error);
+ toast({
+ title: "Error",
+ description: "Failed to switch workspace",
+ variant: "destructive",
+ });
+ } finally {
+ setIsLoading(false);
+ }
+ }}
+ >
+ Switch to this workspace
+
+
+ )}
+
+
+
+
+
+
+
+ Workspace Users
+ {!workspace.is_default && (
+
+ router.push(`/workspaces/${workspaceId}/users/invite`)
+ }
+ >
+
+ Invite User
+
+ )}
+
+
+ {workspace.is_default
+ ? "This is a default workspace. User management is restricted."
+ : "Manage users who have access to this workspace"}
+
+
+
+ {workspaceUsers.length === 0 ? (
+
+
+
No users found
+
+ {workspace.is_default
+ ? "Default workspaces automatically include all users."
+ : "Invite users to collaborate in this workspace"}
+
+
+ ) : (
+
+
+
User
+
Role
+
Joined
+
Actions
+
+ {workspaceUsers.map((user) => {
+ // Find the current user to determine if they have admin rights
+ const isCurrentUserAdmin =
+ workspaceUsers.find(
+ (u) => u.username === currentWorkspace?.username
+ )?.role === "admin";
+
+ return (
+
+
+
+ {user.first_name} {user.last_name}
+
+
+ {user.username}
+
+
+
+ {user.role.toLowerCase()}
+ {user.is_default_workspace && (
+
+ Default
+
+ )}
+
+
+ {new Date(
+ user.created_datetime_utc
+ ).toLocaleDateString()}
+
+
+ {!workspace.is_default && isCurrentUserAdmin && (
+
+
+
+ Remove
+
+
+
+
+
+ Remove user?
+
+
+ Are you sure you want to remove{" "}
+ {user.first_name} {user.last_name} from
+ this workspace?
+
+
+
+ {
+ e.preventDefault();
+ e.stopPropagation();
+ }}
+ >
+ Cancel
+
+
+ handleRemoveUser(user.username)
+ }
+ >
+ Remove
+
+
+
+
+ )}
+
+
+ );
+ })}
+
+ )}
+
+
+
+
+
+
+
+
+ API Configuration
+
+
+ {isRotatingKey ? "Rotating..." : "Rotate API Key"}
+
+
+
+ Manage API settings for this workspace
+
+
+
+
+
Current API Key
+
+ {workspace.api_key_first_characters}
+ •••••••••••••••••••••••••••
+
+
+ Last updated on{" "}
+ {new Date(
+ workspace.api_key_updated_datetime_utc
+ ).toLocaleString()}
+
+
+
+
+
API Usage Limits
+
+
+
+ Daily Quota
+
+
+ {workspace.api_daily_quota}
+
+
+ API calls per day
+
+
+
+
+ Content Quota
+
+
+ {workspace.content_quota}
+
+
+ Experiments
+
+
+
+
+
+
+
+
+
+ About API Key Rotation
+
+
+ When you rotate your API key, the old key will be
+ immediately invalidated. Any services or applications using
+ the old key will need to be updated with the new key. Make
+ sure to copy and save your new key as it will only be shown
+ once.
+
+
+
+
+
+
+
+
+ {/* Dialog for showing the new API key */}
+
+
+
+ New API Key Generated
+
+ Make sure to copy your new API key. You won't be able to see it
+ again!
+
+
+
+
+ {newApiKey}
+
+
+
+
+
+ Copy to Clipboard
+
+ setIsApiKeyDialogOpen(false)}>
+ I've Saved My Key
+
+
+
+
+ >
+ );
+}
diff --git a/frontend/src/app/(protected)/workspaces/[workspaceId]/users/invite/page.tsx b/frontend/src/app/(protected)/workspaces/[workspaceId]/users/invite/page.tsx
new file mode 100644
index 0000000..d2d2644
--- /dev/null
+++ b/frontend/src/app/(protected)/workspaces/[workspaceId]/users/invite/page.tsx
@@ -0,0 +1,244 @@
+"use client";
+
+import { useState } from "react";
+import { useParams, useRouter } from "next/navigation";
+import { useAuth } from "@/utils/auth";
+import { apiCalls } from "@/utils/api";
+import { useToast } from "@/hooks/use-toast";
+import { z } from "zod";
+import { useForm } from "react-hook-form";
+import { zodResolver } from "@hookform/resolvers/zod";
+
+import { ChevronLeftIcon, Send, Users } from "lucide-react";
+import { Button } from "@/components/ui/button";
+import {
+ Card,
+ CardContent,
+ CardDescription,
+ CardFooter,
+ CardHeader,
+ CardTitle,
+} from "@/components/ui/card";
+import {
+ Breadcrumb,
+ BreadcrumbItem,
+ BreadcrumbLink,
+ BreadcrumbList,
+ BreadcrumbPage,
+ BreadcrumbSeparator,
+} from "@/components/ui/breadcrumb";
+import {
+ Form,
+ FormControl,
+ FormDescription,
+ FormField,
+ FormItem,
+ FormLabel,
+ FormMessage,
+} from "@/components/ui/form";
+import { Input } from "@/components/ui/input";
+import {
+ Select,
+ SelectContent,
+ SelectItem,
+ SelectTrigger,
+ SelectValue,
+} from "@/components/ui/select";
+
+// Create schema for invitation form
+const InviteFormSchema = z.object({
+ email: z
+ .string()
+ .email("Please enter a valid email address"),
+ role: z
+ .string()
+ .refine(val => ["admin", "read_only"].includes(val), {
+ message: "Please select a valid role",
+ }),
+});
+
+type InviteFormValues = z.infer;
+
+export default function InviteUserPage() {
+ const params = useParams();
+ const router = useRouter();
+ const { token } = useAuth();
+ const { toast } = useToast();
+ const [isInviting, setIsInviting] = useState(false);
+ const [inviteSuccess, setInviteSuccess] = useState(false);
+
+ const workspaceId = Number(params.workspaceId);
+
+ // Initialize form
+ const form = useForm({
+ resolver: zodResolver(InviteFormSchema),
+ defaultValues: {
+ email: "",
+ role: "read_only", // Default to read-only
+ },
+ });
+
+ const onSubmit = async (data: InviteFormValues) => {
+ if (!token) return;
+
+ setIsInviting(true);
+ try {
+ // First need to get workspace name
+ const workspace = await apiCalls.getWorkspaceById(token, workspaceId);
+
+ // Send invitation
+ const response = await apiCalls.inviteUserToWorkspace(
+ token,
+ data.email,
+ workspace.workspace_name,
+ data.role
+ );
+
+ // Show success message
+ setInviteSuccess(true);
+ toast({
+ title: "Invitation sent",
+ description: `${data.email} has been invited to the workspace`,
+ variant: "success",
+ });
+
+ // Reset form
+ form.reset();
+ } catch (error: any) {
+ toast({
+ title: "Error",
+ description: error.message || "Failed to send invitation",
+ variant: "destructive",
+ });
+ } finally {
+ setIsInviting(false);
+ }
+ };
+
+ return (
+
+
+
+
+
+ Home
+
+
+
+ Workspaces
+
+
+
+ Workspace
+
+
+
+ Invite User
+
+
+
+
+
+
+
+
router.back()}>
+
+ Back
+
+
+
+
+
+ Invite User to Workspace
+
+ Send an invitation to join this workspace
+
+
+
+
+
+
+ {inviteSuccess && (
+
+
+ Invitation Sent Successfully
+
+
+ An email has been sent to the user with instructions to join the workspace.
+
+
+ )}
+
+
+ router.push(`/workspaces/${workspaceId}`)}
+ >
+ Return to Workspace
+
+
+
+
+ );
+}
\ No newline at end of file
diff --git a/frontend/src/app/(protected)/workspaces/page.tsx b/frontend/src/app/(protected)/workspaces/page.tsx
new file mode 100644
index 0000000..2e7b6bc
--- /dev/null
+++ b/frontend/src/app/(protected)/workspaces/page.tsx
@@ -0,0 +1,243 @@
+"use client";
+
+import React, { useState, useEffect } from "react";
+import { useAuth } from "@/utils/auth";
+import { Building, Plus, Settings } from "lucide-react";
+import { Button } from "@/components/ui/button";
+import {
+ Card,
+ CardContent,
+ CardDescription,
+ CardFooter,
+ CardHeader,
+ CardTitle,
+} from "@/components/ui/card";
+import {
+ Breadcrumb,
+ BreadcrumbItem,
+ BreadcrumbLink,
+ BreadcrumbList,
+ BreadcrumbPage,
+ BreadcrumbSeparator,
+} from "@/components/ui/breadcrumb";
+import Hourglass from "@/components/Hourglass";
+import { Badge } from "@/components/ui/badge";
+import { Separator } from "@/components/ui/separator";
+import Link from "next/link";
+import { CreateWorkspaceDialog } from "@/components/create-workspace-dialog";
+import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
+import router from "next/router";
+
+export default function WorkspacesPage() {
+ const { workspaces, currentWorkspace, fetchWorkspaces, switchWorkspace } = useAuth();
+ const [isLoading, setIsLoading] = useState(true);
+ const [isCreateDialogOpen, setIsCreateDialogOpen] = useState(false);
+
+ useEffect(() => {
+ const loadWorkspaces = async () => {
+ setIsLoading(true);
+ try {
+ await fetchWorkspaces();
+ } catch (error) {
+ console.error("Error loading workspaces:", error);
+ } finally {
+ setIsLoading(false);
+ }
+ };
+
+ loadWorkspaces();
+ }, []);
+
+ const handleWorkspaceSwitch = async (workspaceName: string) => {
+ try {
+ setIsLoading(true);
+ await switchWorkspace(workspaceName);
+ } catch (error) {
+ console.error("Error switching workspace:", error);
+ } finally {
+ setIsLoading(false);
+ }
+ };
+
+ if (isLoading) {
+ return (
+
+
+
+ Loading workspaces...
+
+
+ );
+ }
+
+ return (
+ <>
+
+
+
+
+
+ Home
+
+
+
+ Workspaces
+
+
+
+
+
+
+
+
Workspaces
+
+ Manage your workspaces and team members
+
+
+
setIsCreateDialogOpen(true)}>
+
+ New Workspace
+
+
+
+
+
+ All Workspaces
+ Current Workspace
+
+
+
+
+ {workspaces.map((workspace) => (
+
+
+
+ {workspace.is_default && (
+ Default
+ )}
+ {currentWorkspace?.workspace_id === workspace.workspace_id && (
+ Current
+ )}
+
+
+
+
+
+
{workspace.workspace_name}
+
+
+ ID: {workspace.workspace_id}
+
+
+
+
+
+ API Quota:
+ {workspace.api_daily_quota} calls/day
+
+
+ Content Quota:
+ {workspace.content_quota} experiments
+
+
+ API Key:
+ {workspace.api_key_first_characters}•••••
+
+
+ Created:
+ {new Date(workspace.created_datetime_utc).toLocaleDateString()}
+
+
+
+
+
+ handleWorkspaceSwitch(workspace.workspace_name)}
+ >
+ Switch to
+
+
+
+
+ Manage
+
+
+
+
+ ))}
+
+
+
+
+ {currentWorkspace && (
+
+
+
+
+
+
+
+
+ {currentWorkspace.workspace_name}
+
+ Workspace ID: {currentWorkspace.workspace_id}
+
+
+
+
+
+
+
+
Workspace Details
+
+
Created:
+
{new Date(currentWorkspace.created_datetime_utc).toLocaleDateString()}
+
+
Last Updated:
+
{new Date(currentWorkspace.updated_datetime_utc).toLocaleDateString()}
+
+
API Daily Quota:
+
{currentWorkspace.api_daily_quota} calls/day
+
+
Content Quota:
+
{currentWorkspace.content_quota} experiments
+
+
+
+
+
API Configuration
+
+
API Key Prefix:
+
{currentWorkspace.api_key_first_characters}•••••
+
+
Key Last Rotated:
+
{new Date(currentWorkspace.api_key_updated_datetime_utc).toLocaleDateString()}
+
+
+
+
+
+ router.push(`/workspaces/${currentWorkspace.workspace_id}`)}
+ >
+
+ Manage Workspace
+
+
+
+
+ )}
+
+
+
+
+
+ >
+ );
+}
diff --git a/frontend/src/app/(protected)/workspaces/types.ts b/frontend/src/app/(protected)/workspaces/types.ts
new file mode 100644
index 0000000..6a0c8de
--- /dev/null
+++ b/frontend/src/app/(protected)/workspaces/types.ts
@@ -0,0 +1,22 @@
+export type WorkspaceUser = {
+ user_id: number;
+ username: string;
+ first_name: string;
+ last_name: string;
+ role: string;
+ is_default_workspace: boolean;
+ created_datetime_utc: string;
+};
+
+export type Workspace = {
+ workspace_id: number;
+ workspace_name: string;
+ api_key_first_characters: string;
+ api_daily_quota: number;
+ content_quota: number;
+ created_datetime_utc: string;
+ updated_datetime_utc: string;
+ api_key_updated_datetime_utc: string;
+ is_default: boolean;
+};
+
diff --git a/frontend/src/components/WorkspaceSelector.tsx b/frontend/src/components/WorkspaceSelector.tsx
deleted file mode 100644
index 4692888..0000000
--- a/frontend/src/components/WorkspaceSelector.tsx
+++ /dev/null
@@ -1,142 +0,0 @@
-"use client";
-
-import React, { useState } from "react";
-import { useAuth } from "@/utils/auth";
-import { Button } from "@/components/catalyst/button";
-import {
- Dialog,
- DialogActions,
- DialogBody,
- DialogDescription,
- DialogTitle,
-} from "@/components/catalyst/dialog";
-import { BuildingOfficeIcon, ChevronUpDownIcon, PlusIcon } from "@heroicons/react/20/solid";
-import {
- DropdownItem,
- DropdownLabel,
- DropdownMenu,
- DropdownButton,
- DropdownDivider,
- Dropdown,
-} from "@/components/catalyst/dropdown";
-import { useToast } from "@/hooks/use-toast";
-import { useRouter } from "next/navigation";
-
-export default function WorkspaceSelector() {
- const { currentWorkspace, workspaces, switchWorkspace, isLoading } = useAuth();
- const [isOpen, setIsOpen] = useState(false);
- const { toast } = useToast();
- const router = useRouter();
-
- const handleSwitchWorkspace = async (workspaceName: string) => {
- try {
- await switchWorkspace(workspaceName);
- toast({
- title: "Workspace Changed",
- description: `Switched to workspace: ${workspaceName}`,
- });
- } catch (error) {
- toast({
- title: "Error",
- description: "Failed to switch workspace",
- variant: "destructive",
- });
- }
- };
-
- const handleCreateWorkspace = () => {
- router.push("/workspace/create");
- };
-
- if (isLoading || !currentWorkspace) {
- return (
-
-
-
- Loading...
-
-
-
- );
- }
-
- return (
-
-
-
-
-
- {currentWorkspace.workspace_name}
-
-
-
-
-
- {workspaces.map((workspace) => (
- handleSwitchWorkspace(workspace.workspace_name)}
- >
-
- {workspace.workspace_name}
-
- ))}
-
-
-
-
-
- Create New Workspace
-
-
-
-
-
setIsOpen(false)}>
- Switch Workspace
-
- Select a workspace to switch to
-
-
-
- {workspaces.map((workspace) => (
-
{
- handleSwitchWorkspace(workspace.workspace_name);
- setIsOpen(false);
- }}
- >
-
{workspace.workspace_name}
- {workspace.workspace_id === currentWorkspace.workspace_id && (
-
Current
- )}
-
- ))}
-
-
-
- setIsOpen(false)}>
- Cancel
-
-
-
- Create New
-
-
-
-
- );
-}
diff --git a/frontend/src/components/app-sidebar.tsx b/frontend/src/components/app-sidebar.tsx
index 6c8ee29..8495693 100644
--- a/frontend/src/components/app-sidebar.tsx
+++ b/frontend/src/components/app-sidebar.tsx
@@ -1,16 +1,11 @@
"use client";
import * as React from "react";
import {
- AudioWaveform,
ArrowLeftRightIcon,
LayoutDashboardIcon,
- Command,
- Frame,
- GalleryVerticalEnd,
- Map,
- PieChart,
- Settings2,
FlaskConicalIcon,
+ Settings2,
+ Building,
} from "lucide-react";
import { NavMain } from "@/components/nav-main";
import { NavRecentExperiments } from "@/components/nav-recent-experiments";
@@ -23,128 +18,68 @@ import {
SidebarHeader,
SidebarRail,
} from "@/components/ui/sidebar";
-import api from "@/utils/api";
import { useAuth } from "@/utils/auth";
-type UserDetails = {
- username: string;
- firstName: string;
- lastName: string;
- isActive: boolean;
- isVerified: boolean;
-};
-
-const getUserDetails = async (token: string | null) => {
- try {
- const response = await api.get("/user", {
- headers: {
- Authorization: `Bearer ${token}`,
- },
- });
-
- return {
- username: response.data.username,
- firstName: response.data.first_name,
- lastName: response.data.last_name,
- isActive: response.data.is_active,
- isVerified: response.data.is_verified,
- } as UserDetails;
- } catch (error: unknown) {
- if (error instanceof Error) {
- throw new Error(`Error fetching user details: ${error.message}`);
- } else {
- throw new Error("Error fetching user details");
- }
- }
-};
+const AppSidebar = React.memo(function AppSidebar({
+ ...props
+}: React.ComponentProps) {
+ const { user, firstName, lastName } = useAuth();
-// This is sample data.
-const data = {
- user: {
- name: "shadcn",
- email: "m@example.com",
- avatar: "/avatars/shadcn.jpg",
- },
- workspaces: [
- {
- name: "Acme Inc",
- logo: GalleryVerticalEnd,
- plan: "Enterprise",
- },
- {
- name: "Acme Corp.",
- logo: AudioWaveform,
- plan: "Startup",
- },
- {
- name: "Evil Corp.",
- logo: Command,
- plan: "Free",
- },
- ],
- navMain: [
+ const navMain = [
{
title: "Experiments",
url: "/experiments",
icon: FlaskConicalIcon,
},
- {
- title: "Integration",
- url: "/integration",
- icon: ArrowLeftRightIcon,
- },
{
title: "Dashboard",
url: "#",
icon: LayoutDashboardIcon,
},
+ {
+ title: "Workspaces",
+ url: "/workspaces",
+ icon: Building,
+ },
{
title: "Settings",
url: "#",
icon: Settings2,
},
- ],
- recentExperiments: [
+ ];
+
+ const recentExperiments = [
{
name: "New onboarding flows",
url: "#",
- icon: Frame,
+ icon: FlaskConicalIcon,
},
{
name: "3 different voices",
url: "#",
- icon: PieChart,
+ icon: FlaskConicalIcon,
},
{
name: "AI responses",
url: "#",
- icon: Map,
+ icon: FlaskConicalIcon,
},
- ],
-};
-const AppSidebar = React.memo(function AppSidebar({
- ...props
-}: React.ComponentProps) {
- const { token } = useAuth();
- const [userDetails, setUserDetails] = React.useState(
- null
- );
+ ];
+
+ const userDetails = {
+ firstName: firstName || "?",
+ lastName: lastName || "?",
+ username: user || "loading",
+ };
- React.useEffect(() => {
- if (token) {
- getUserDetails(token)
- .then((data) => setUserDetails(data))
- .catch((error) => console.error(error));
- }
- }, [token]);
return (
-
+
-
-
+
+
diff --git a/frontend/src/components/create-workspace-dialog.tsx b/frontend/src/components/create-workspace-dialog.tsx
new file mode 100644
index 0000000..8a5b75b
--- /dev/null
+++ b/frontend/src/components/create-workspace-dialog.tsx
@@ -0,0 +1,152 @@
+"use client";
+
+import { useState } from "react";
+import { apiCalls } from "@/utils/api";
+import { useAuth } from "@/utils/auth";
+import { useToast } from "@/hooks/use-toast";
+import { useRouter } from "next/navigation";
+import { z } from "zod";
+import { useForm } from "react-hook-form";
+import { zodResolver } from "@hookform/resolvers/zod";
+
+import {
+ Dialog,
+ DialogContent,
+ DialogDescription,
+ DialogFooter,
+ DialogHeader,
+ DialogTitle,
+} from "@/components/ui/dialog";
+import { Button } from "@/components/ui/button";
+import {
+ Form,
+ FormControl,
+ FormDescription,
+ FormField,
+ FormItem,
+ FormLabel,
+ FormMessage,
+} from "@/components/ui/form";
+import { Input } from "@/components/ui/input";
+
+// Create schema for workspace form
+const WorkspaceFormSchema = z.object({
+ workspaceName: z
+ .string()
+ .min(3, "Workspace name must be at least 3 characters")
+ .max(50, "Workspace name must be less than 50 characters"),
+ apiDailyQuota: z
+ .number()
+ .int("API quota must be an integer")
+ .min(1, "API quota must be at least 1")
+ .optional(),
+ contentQuota: z
+ .number()
+ .int("Content quota must be an integer")
+ .min(1, "Content quota must be at least 1")
+ .optional(),
+});
+
+type WorkspaceFormValues = z.infer;
+
+interface CreateWorkspaceDialogProps {
+ open: boolean;
+ onOpenChange: (open: boolean) => void;
+}
+
+export function CreateWorkspaceDialog({
+ open,
+ onOpenChange,
+}: CreateWorkspaceDialogProps) {
+ const { token, fetchWorkspaces } = useAuth();
+ const router = useRouter();
+ const { toast } = useToast();
+ const [isCreating, setIsCreating] = useState(false);
+
+ // Initialize form
+ const form = useForm({
+ resolver: zodResolver(WorkspaceFormSchema),
+ defaultValues: {
+ workspaceName: "",
+ apiDailyQuota: undefined,
+ contentQuota: undefined,
+ },
+ });
+
+ async function onSubmit(data: WorkspaceFormValues) {
+ if (!token) return;
+
+ setIsCreating(true);
+ try {
+ await apiCalls.createWorkspace(
+ token,
+ data.workspaceName,
+ data.apiDailyQuota,
+ data.contentQuota
+ );
+
+ toast({
+ title: "Workspace created",
+ description: `${data.workspaceName} workspace was created successfully.`,
+ variant: "success",
+ });
+
+ // Refresh the workspaces list
+ await fetchWorkspaces();
+
+ // Close the dialog and reset form
+ onOpenChange(false);
+ form.reset();
+
+ // Navigate to workspaces page
+ router.push("/workspaces");
+ } catch (error: any) {
+ toast({
+ title: "Error creating workspace",
+ description: error.message || "Failed to create workspace. Please try again.",
+ variant: "destructive",
+ });
+ } finally {
+ setIsCreating(false);
+ }
+ }
+
+ return (
+
+
+
+ Create New Workspace
+
+ Create a new workspace for your team or project
+
+
+
+
+
+
+
+ );
+}
diff --git a/frontend/src/components/ui/alert-dialog.tsx b/frontend/src/components/ui/alert-dialog.tsx
new file mode 100644
index 0000000..1e2629a
--- /dev/null
+++ b/frontend/src/components/ui/alert-dialog.tsx
@@ -0,0 +1,143 @@
+"use client"
+
+import * as React from "react"
+import * as AlertDialogPrimitive from "@radix-ui/react-alert-dialog"
+
+import { cn } from "@/lib/utils"
+import { buttonVariants } from "@/components/ui/button"
+
+const AlertDialog = AlertDialogPrimitive.Root
+const AlertDialogTrigger = AlertDialogPrimitive.Trigger
+const AlertDialogPortal = AlertDialogPrimitive.Portal
+
+const AlertDialogOverlay = React.forwardRef<
+ React.ElementRef,
+ React.ComponentPropsWithoutRef
+>(({ className, ...props }, ref) => (
+
+))
+AlertDialogOverlay.displayName = AlertDialogPrimitive.Overlay.displayName
+
+const AlertDialogContent = React.forwardRef<
+ React.ElementRef,
+ React.ComponentPropsWithoutRef
+>(({ className, ...props }, ref) => (
+
+
+
+
+))
+AlertDialogContent.displayName = AlertDialogPrimitive.Content.displayName
+
+const AlertDialogHeader = ({
+ className,
+ ...props
+}: React.HTMLAttributes) => (
+
+)
+AlertDialogHeader.displayName = "AlertDialogHeader"
+
+const AlertDialogFooter = ({
+ className,
+ ...props
+}: React.HTMLAttributes) => (
+
+)
+AlertDialogFooter.displayName = "AlertDialogFooter"
+
+const AlertDialogTitle = React.forwardRef<
+ React.ElementRef,
+ React.ComponentPropsWithoutRef
+>(({ className, ...props }, ref) => (
+
+))
+AlertDialogTitle.displayName = AlertDialogPrimitive.Title.displayName
+
+const AlertDialogDescription = React.forwardRef<
+ React.ElementRef,
+ React.ComponentPropsWithoutRef
+>(({ className, ...props }, ref) => (
+
+))
+AlertDialogDescription.displayName =
+ AlertDialogPrimitive.Description.displayName
+
+const AlertDialogAction = React.forwardRef<
+ HTMLButtonElement,
+ React.ButtonHTMLAttributes
+>(({ className, ...props }, ref) => (
+
+))
+AlertDialogAction.displayName = "AlertDialogAction"
+
+const AlertDialogCancel = React.forwardRef<
+ HTMLButtonElement,
+ React.ButtonHTMLAttributes
+>(({ className, ...props }, ref) => (
+
+))
+AlertDialogCancel.displayName = "AlertDialogCancel"
+
+export {
+ AlertDialog,
+ AlertDialogPortal,
+ AlertDialogOverlay,
+ AlertDialogTrigger,
+ AlertDialogContent,
+ AlertDialogHeader,
+ AlertDialogFooter,
+ AlertDialogTitle,
+ AlertDialogDescription,
+ AlertDialogAction,
+ AlertDialogCancel,
+}
\ No newline at end of file
diff --git a/frontend/src/components/ui/tabs.tsx b/frontend/src/components/ui/tabs.tsx
new file mode 100644
index 0000000..8873b85
--- /dev/null
+++ b/frontend/src/components/ui/tabs.tsx
@@ -0,0 +1,55 @@
+"use client"
+
+import * as React from "react"
+import * as TabsPrimitive from "@radix-ui/react-tabs"
+
+import { cn } from "@/lib/utils"
+
+const Tabs = TabsPrimitive.Root
+
+const TabsList = React.forwardRef<
+ React.ElementRef,
+ React.ComponentPropsWithoutRef
+>(({ className, ...props }, ref) => (
+
+))
+TabsList.displayName = TabsPrimitive.List.displayName
+
+const TabsTrigger = React.forwardRef<
+ React.ElementRef,
+ React.ComponentPropsWithoutRef
+>(({ className, ...props }, ref) => (
+
+))
+TabsTrigger.displayName = TabsPrimitive.Trigger.displayName
+
+const TabsContent = React.forwardRef<
+ React.ElementRef,
+ React.ComponentPropsWithoutRef
+>(({ className, ...props }, ref) => (
+
+))
+TabsContent.displayName = TabsPrimitive.Content.displayName
+
+export { Tabs, TabsList, TabsTrigger, TabsContent }
\ No newline at end of file
diff --git a/frontend/src/components/workspace-switcher.tsx b/frontend/src/components/workspace-switcher.tsx
index 3e03c70..df795d9 100644
--- a/frontend/src/components/workspace-switcher.tsx
+++ b/frontend/src/components/workspace-switcher.tsx
@@ -1,7 +1,8 @@
"use client";
import * as React from "react";
-import { ChevronsUpDown, Plus } from "lucide-react";
+import { ChevronsUpDown, FlaskConical } from "lucide-react";
+import { useAuth } from "@/utils/auth";
import {
DropdownMenu,
@@ -19,39 +20,47 @@ import {
useSidebar,
} from "@/components/ui/sidebar";
-export function WorkspaceSwitcher({
- workspaces,
-}: {
- workspaces: {
- name: string;
- logo: React.ElementType;
- plan: string;
- }[];
-}) {
+export function WorkspaceSwitcher() {
const { isMobile } = useSidebar();
- const [activeWorkspace, setActiveWorkspace] = React.useState(workspaces[0]);
+ const { currentWorkspace, workspaces, switchWorkspace } = useAuth();
+ const [isLoading, setIsLoading] = React.useState(false);
- if (!activeWorkspace) {
+ if (!currentWorkspace) {
return null;
}
+ const handleWorkspaceSwitch = async (workspaceName: string) => {
+ if (workspaceName === currentWorkspace.workspace_name) return;
+
+ try {
+ setIsLoading(true);
+ await switchWorkspace(workspaceName);
+ } catch (error) {
+ console.error("Error switching workspace:", error);
+ } finally {
+ setIsLoading(false);
+ }
+ };
+
return (
-
+
- {activeWorkspace.name}
+ {currentWorkspace.workspace_name}
+
+
+ {isLoading ? "Switching..." : "Workspace"}
- {activeWorkspace.plan}
@@ -67,25 +76,27 @@ export function WorkspaceSwitcher({
{workspaces.map((workspace, index) => (
setActiveWorkspace(workspace)}
- className="gap-2 p-2"
+ key={workspace.workspace_id}
+ onClick={() => handleWorkspaceSwitch(workspace.workspace_name)}
+ className={`gap-2 p-2 ${workspace.workspace_id === currentWorkspace.workspace_id ? "bg-accent" : ""}`}
>
-
+
- {workspace.name}
+ {workspace.workspace_name}
⌘{index + 1}
))}
-
-
- Add workspace
-
+
+
+ +
+
+
+ Add workspace
+
+
diff --git a/frontend/src/utils/api.ts b/frontend/src/utils/api.ts
index c5c6e18..597e6e0 100644
--- a/frontend/src/utils/api.ts
+++ b/frontend/src/utils/api.ts
@@ -97,61 +97,75 @@ const registerUser = async (
}
};
-const getCurrentWorkspace = async (token: string | null) => {
+const requestPasswordReset = async (username: string) => {
try {
- const response = await api.get("/workspace/current", {
- headers: {
- Authorization: `Bearer ${token}`,
- },
- });
+ const response = await api.post("/request-password-reset", { username });
return response.data;
} catch (error) {
- throw new Error("Error fetching current workspace");
+ throw new Error("Error requesting password reset");
}
};
-const getAllWorkspaces = async (token: string | null) => {
+const resetPassword = async (token: string, newPassword: string) => {
try {
- const response = await api.get("/workspace/", {
- headers: {
- Authorization: `Bearer ${token}`,
- },
+ const response = await api.post("/reset-password", {
+ token,
+ new_password: newPassword,
});
return response.data;
} catch (error) {
- throw new Error("Error fetching workspaces");
+ throw new Error("Error resetting password");
}
};
-const createWorkspace = async (token: string | null, workspaceData: any) => {
+const verifyEmail = async (token: string) => {
try {
- const response = await api.post("/workspace/", workspaceData, {
+ const response = await api.post("/verify-email", { token });
+ return response.data;
+ } catch (error) {
+ throw new Error("Error verifying email");
+ }
+};
+
+const resendVerification = async (username: string) => {
+ try {
+ const response = await api.post("/resend-verification", { username });
+ return response.data;
+ } catch (error) {
+ throw new Error("Error resending verification email");
+ }
+};
+
+const getUserWorkspaces = async (token: string | null) => {
+ try {
+ const response = await api.get("/workspace/", {
headers: {
Authorization: `Bearer ${token}`,
},
});
return response.data;
} catch (error) {
- throw new Error("Error creating workspace");
+ throw new Error("Error fetching user workspaces");
}
};
-const updateWorkspace = async (token: string | null, workspaceId: number, workspaceData: any) => {
+const getCurrentWorkspace = async (token: string | null) => {
try {
- const response = await api.put(`/workspace/${workspaceId}`, workspaceData, {
+ const response = await api.get("/workspace/current", {
headers: {
Authorization: `Bearer ${token}`,
},
});
return response.data;
} catch (error) {
- throw new Error("Error updating workspace");
+ throw new Error("Error fetching current workspace");
}
};
const switchWorkspace = async (token: string | null, workspaceName: string) => {
try {
- const response = await api.post("/workspace/switch",
+ const response = await api.post(
+ "/workspace/switch",
{ workspace_name: workspaceName },
{
headers: {
@@ -165,68 +179,141 @@ const switchWorkspace = async (token: string | null, workspaceName: string) => {
}
};
-const rotateWorkspaceApiKey = async (token: string | null) => {
+const createWorkspace = async (token: string | null, workspaceName: string,
+ apiDailyQuota?: number, contentQuota?: number) => {
try {
- const response = await api.put("/workspace/rotate-key", {}, {
- headers: {
- Authorization: `Bearer ${token}`,
+ const response = await api.post(
+ "/workspace/",
+ {
+ workspace_name: workspaceName,
+ api_daily_quota: apiDailyQuota,
+ content_quota: contentQuota
},
- });
+ {
+ headers: {
+ Authorization: `Bearer ${token}`,
+ },
+ }
+ );
return response.data;
} catch (error) {
- throw new Error("Error rotating workspace API key");
+ throw new Error("Error creating workspace");
}
};
-const inviteUserToWorkspace = async (token: string | null, inviteData: any) => {
+const updateWorkspace = async (
+ token: string | null,
+ workspaceId: number,
+ workspaceName?: string,
+ apiDailyQuota?: number,
+ contentQuota?: number
+) => {
try {
- const response = await api.post("/workspace/invite", inviteData, {
- headers: {
- Authorization: `Bearer ${token}`,
+ const response = await api.put(
+ `/workspace/${workspaceId}`,
+ {
+ workspace_name: workspaceName,
+ api_daily_quota: apiDailyQuota,
+ content_quota: contentQuota,
},
- });
+ {
+ headers: {
+ Authorization: `Bearer ${token}`,
+ },
+ }
+ );
return response.data;
} catch (error) {
- throw new Error("Error inviting user to workspace");
+ throw new Error("Error updating workspace");
}
};
-const requestPasswordReset = async (username: string) => {
+const rotateWorkspaceApiKey = async (token: string | null) => {
try {
- const response = await api.post("/request-password-reset", { username });
+ const response = await api.put(
+ "/workspace/rotate-key",
+ {},
+ {
+ headers: {
+ Authorization: `Bearer ${token}`,
+ },
+ }
+ );
return response.data;
} catch (error) {
- throw new Error("Error requesting password reset");
+ throw new Error("Error rotating workspace API key");
}
};
-const resetPassword = async (token: string, newPassword: string) => {
+const getWorkspaceById = async (token: string | null, workspaceId: number) => {
try {
- const response = await api.post("/reset-password", {
- token,
- new_password: newPassword,
+ const response = await api.get(`/workspace/${workspaceId}`, {
+ headers: {
+ Authorization: `Bearer ${token}`,
+ },
});
return response.data;
} catch (error) {
- throw new Error("Error resetting password");
+ throw new Error("Error fetching workspace details");
}
};
-const verifyEmail = async (token: string) => {
+const getWorkspaceUsers = async (token: string | null, workspaceId: number) => {
try {
- const response = await api.post("/verify-email", { token });
+ const response = await api.get(`/workspace/${workspaceId}/users`, {
+ headers: {
+ Authorization: `Bearer ${token}`,
+ },
+ });
return response.data;
} catch (error) {
- throw new Error("Error verifying email");
+ throw new Error("Error fetching workspace users");
}
};
-const resendVerification = async (username: string) => {
+const inviteUserToWorkspace = async (
+ token: string | null,
+ email: string,
+ workspaceName: string,
+ role: string
+) => {
try {
- const response = await api.post("/resend-verification", { username });
+ const response = await api.post(
+ "/workspace/invite",
+ {
+ email,
+ workspace_name: workspaceName,
+ role,
+ },
+ {
+ headers: {
+ Authorization: `Bearer ${token}`,
+ },
+ }
+ );
return response.data;
} catch (error) {
- throw new Error("Error resending verification email");
+ throw new Error("Error inviting user to workspace");
+ }
+};
+
+const removeUserFromWorkspace = async (
+ token: string | null,
+ workspaceId: number,
+ username: string
+) => {
+ try {
+ const response = await api.delete(
+ `/workspace/${workspaceId}/users/${username}`,
+ {
+ headers: {
+ Authorization: `Bearer ${token}`,
+ },
+ }
+ );
+ return response.data;
+ } catch (error) {
+ throw new Error("Error removing user from workspace");
}
};
@@ -239,12 +326,15 @@ export const apiCalls = {
resetPassword,
verifyEmail,
resendVerification,
+ getUserWorkspaces,
getCurrentWorkspace,
- getAllWorkspaces,
+ switchWorkspace,
createWorkspace,
updateWorkspace,
- switchWorkspace,
rotateWorkspaceApiKey,
+ getWorkspaceById,
+ getWorkspaceUsers,
inviteUserToWorkspace,
+ removeUserFromWorkspace,
};
export default api;
diff --git a/frontend/src/utils/auth.tsx b/frontend/src/utils/auth.tsx
index ee6650d..46a243b 100644
--- a/frontend/src/utils/auth.tsx
+++ b/frontend/src/utils/auth.tsx
@@ -7,12 +7,19 @@ type Workspace = {
workspace_id: number;
workspace_name: string;
api_key_first_characters: string;
+ api_daily_quota: number;
+ content_quota: number;
+ created_datetime_utc: string;
+ updated_datetime_utc: string;
+ api_key_updated_datetime_utc: string;
is_default: boolean;
};
type AuthContextType = {
token: string | null;
user: string | null;
+ firstName: string | null;
+ lastName: string | null;
isVerified: boolean;
isLoading: boolean;
currentWorkspace: Workspace | null;
@@ -20,7 +27,6 @@ type AuthContextType = {
login: (username: string, password: string) => Promise;
logout: () => void;
loginError: string | null;
- switchWorkspace: (workspaceName: string) => Promise;
loginGoogle: ({
client_id,
credential,
@@ -28,6 +34,9 @@ type AuthContextType = {
client_id: string;
credential: string;
}) => void;
+ fetchWorkspaces: () => Promise;
+ switchWorkspace: (workspaceName: string) => Promise;
+ rotateWorkspaceApiKey: () => Promise;
};
const AuthContext = createContext(undefined);
@@ -52,6 +61,8 @@ const AuthProvider = ({ children }: AuthProviderProps) => {
};
const [user, setUser] = useState(getInitialUsername);
+ const [firstName, setFirstName] = useState(null);
+ const [lastName, setLastName] = useState(null);
const [token, setToken] = useState(getInitialToken);
const [isVerified, setIsVerified] = useState(false);
const [isLoading, setIsLoading] = useState(!!getInitialToken());
@@ -63,39 +74,23 @@ const AuthProvider = ({ children }: AuthProviderProps) => {
const router = useRouter();
useEffect(() => {
- const loadWorkspaceInfo = async () => {
- if (token) {
- try {
- setIsLoading(true);
- // Fetch current workspace
- const currentWorkspaceData = await apiCalls.getCurrentWorkspace(token);
- setCurrentWorkspace(currentWorkspaceData);
-
- // Fetch all workspaces
- const workspacesData = await apiCalls.getAllWorkspaces(token);
- setWorkspaces(workspacesData);
- } catch (error) {
- console.error("Error loading workspace info:", error);
- } finally {
- setIsLoading(false);
- }
- }
- };
-
- loadWorkspaceInfo();
- }, [token]);
-
- // Check verification status on init if token exists
- useEffect(() => {
- const checkVerificationStatus = async () => {
+ const checkUserStatus = async () => {
const currentToken = getInitialToken();
if (currentToken) {
setIsLoading(true);
try {
const userData = await apiCalls.getUser(currentToken);
setIsVerified(userData.is_verified);
+ setFirstName(userData.first_name);
+ setLastName(userData.last_name);
+
+ // Fetch current workspace
+ await fetchCurrentWorkspace();
+
+ // Fetch available workspaces
+ await fetchWorkspaces();
} catch (error) {
- console.error("Error fetching user verification status:", error);
+ console.error("Error fetching user status:", error);
logout();
} finally {
setIsLoading(false);
@@ -103,9 +98,73 @@ const AuthProvider = ({ children }: AuthProviderProps) => {
}
};
- checkVerificationStatus();
+ checkUserStatus();
}, []);
+ const fetchCurrentWorkspace = async () => {
+ if (!token) return;
+
+ try {
+ const workspaceData = await apiCalls.getCurrentWorkspace(token);
+ setCurrentWorkspace(workspaceData);
+ } catch (error) {
+ console.error("Error fetching current workspace:", error);
+ }
+ };
+
+ const fetchWorkspaces = async () => {
+ if (!token) return;
+
+ try {
+ const workspacesData = await apiCalls.getUserWorkspaces(token);
+ setWorkspaces(workspacesData);
+ } catch (error) {
+ console.error("Error fetching user workspaces:", error);
+ }
+ };
+
+ const switchWorkspace = async (workspaceName: string) => {
+ if (!token) return;
+
+ try {
+ setIsLoading(true);
+ const authResponse = await apiCalls.switchWorkspace(token, workspaceName);
+
+ // Update token and other auth details
+ localStorage.setItem("ee-token", authResponse.access_token);
+ setToken(authResponse.access_token);
+
+ // Refresh workspace data
+ await fetchCurrentWorkspace();
+
+ return;
+ } catch (error) {
+ console.error("Error switching workspace:", error);
+ throw error;
+ } finally {
+ setIsLoading(false);
+ }
+ };
+
+ const rotateWorkspaceApiKey = async () => {
+ if (!token) throw new Error("Not authenticated");
+
+ try {
+ setIsLoading(true);
+ const response = await apiCalls.rotateWorkspaceApiKey(token);
+
+ // Update workspace to reflect key change
+ await fetchCurrentWorkspace();
+
+ return response.new_api_key;
+ } catch (error) {
+ console.error("Error rotating workspace API key:", error);
+ throw error;
+ } finally {
+ setIsLoading(false);
+ }
+ };
+
const login = async (username: string, password: string) => {
const sourcePage = searchParams.has("sourcePage")
? decodeURIComponent(searchParams.get("sourcePage") as string)
@@ -131,19 +190,27 @@ const AuthProvider = ({ children }: AuthProviderProps) => {
try {
const userData = await apiCalls.getUser(access_token);
setIsVerified(userData.is_verified);
+ setFirstName(userData.first_name);
+ setLastName(userData.last_name);
} catch (error) {
console.error("Error fetching user verification status:", error);
}
}
+ // Fetch current workspace
try {
- const currentWorkspaceData = await apiCalls.getCurrentWorkspace(access_token);
- setCurrentWorkspace(currentWorkspaceData);
+ const workspaceData = await apiCalls.getCurrentWorkspace(access_token);
+ setCurrentWorkspace(workspaceData);
+ } catch (error) {
+ console.error("Error fetching current workspace:", error);
+ }
- const workspacesData = await apiCalls.getAllWorkspaces(access_token);
+ // Fetch all workspaces
+ try {
+ const workspacesData = await apiCalls.getUserWorkspaces(access_token);
setWorkspaces(workspacesData);
} catch (error) {
- console.error("Error loading workspace info:", error);
+ console.error("Error fetching workspaces:", error);
}
// Redirect to verification page if not verified, otherwise to original destination
@@ -168,26 +235,6 @@ const AuthProvider = ({ children }: AuthProviderProps) => {
}
};
- const switchWorkspace = async (workspaceName: string) => {
- try {
- setIsLoading(true);
- const response = await apiCalls.switchWorkspace(token, workspaceName);
-
- localStorage.setItem("ee-token", response.access_token);
- setToken(response.access_token);
-
- const currentWorkspaceData = await apiCalls.getCurrentWorkspace(response.access_token);
- setCurrentWorkspace(currentWorkspaceData);
-
- return response;
- } catch (error) {
- console.error("Error switching workspace:", error);
- throw error;
- } finally {
- setIsLoading(false);
- }
- };
-
const loginGoogle = async ({
client_id,
credential,
@@ -214,16 +261,30 @@ const AuthProvider = ({ children }: AuthProviderProps) => {
setUser(username);
setToken(access_token);
- setIsVerified(true);
+ // Fetch user details
+ try {
+ const userData = await apiCalls.getUser(access_token);
+ setIsVerified(userData.is_verified);
+ setFirstName(userData.first_name);
+ setLastName(userData.last_name);
+ } catch (error) {
+ console.error("Error fetching user details:", error);
+ }
+ // Fetch current workspace
try {
- const currentWorkspaceData = await apiCalls.getCurrentWorkspace(access_token);
- setCurrentWorkspace(currentWorkspaceData);
+ const workspaceData = await apiCalls.getCurrentWorkspace(access_token);
+ setCurrentWorkspace(workspaceData);
+ } catch (error) {
+ console.error("Error fetching current workspace:", error);
+ }
- const workspacesData = await apiCalls.getAllWorkspaces(access_token);
+ // Fetch all workspaces
+ try {
+ const workspacesData = await apiCalls.getUserWorkspaces(access_token);
setWorkspaces(workspacesData);
} catch (error) {
- console.error("Error loading workspace info:", error);
+ console.error("Error fetching workspaces:", error);
}
router.push(sourcePage);
@@ -242,6 +303,8 @@ const AuthProvider = ({ children }: AuthProviderProps) => {
setUser(null);
setToken(null);
setIsVerified(false);
+ setFirstName(null);
+ setLastName(null);
setCurrentWorkspace(null);
setWorkspaces([]);
router.push("/login");
@@ -250,6 +313,8 @@ const AuthProvider = ({ children }: AuthProviderProps) => {
const authValue: AuthContextType = {
token,
user,
+ firstName,
+ lastName,
isVerified,
isLoading,
currentWorkspace,
@@ -257,8 +322,10 @@ const AuthProvider = ({ children }: AuthProviderProps) => {
login,
loginError,
loginGoogle,
- switchWorkspace,
logout,
+ fetchWorkspaces,
+ switchWorkspace,
+ rotateWorkspaceApiKey,
};
return (
From 700fb360d94cb0f86552b3cb80a8639e263f9dd5 Mon Sep 17 00:00:00 2001
From: Jay Prakash <0freerunning@gmail.com>
Date: Thu, 1 May 2025 13:12:43 +0300
Subject: [PATCH 08/74] Remove old sidebar
---
frontend/src/components/sidebar.tsx | 119 ----------------------------
1 file changed, 119 deletions(-)
delete mode 100644 frontend/src/components/sidebar.tsx
diff --git a/frontend/src/components/sidebar.tsx b/frontend/src/components/sidebar.tsx
deleted file mode 100644
index 482b0a6..0000000
--- a/frontend/src/components/sidebar.tsx
+++ /dev/null
@@ -1,119 +0,0 @@
-"use client";
-
-import {
- Sidebar,
- SidebarBody,
- SidebarFooter,
- SidebarHeader,
- SidebarHeading,
- SidebarItem,
- SidebarLabel,
- SidebarSection,
- SidebarSpacer,
- SidebarDivider,
-} from "@/components/catalyst/sidebar";
-import { Avatar } from "@/components/catalyst/avatar";
-import { Dropdown, DropdownButton } from "@/components/catalyst/dropdown";
-import { TeamDropdownMenu, AnchorProps } from "@/components/TeamDropdownMenu";
-import { navItems } from "@/data/navItems";
-import React from "react";
-import { ChevronUpIcon } from "@heroicons/react/16/solid";
-import {
- InboxIcon,
- MagnifyingGlassIcon,
- QuestionMarkCircleIcon,
- SparklesIcon,
- BuildingOfficeIcon,
- UserPlusIcon,
-} from "@heroicons/react/20/solid";
-import { useAuth } from "@/utils/auth";
-import WorkspaceSelector from "./WorkspaceSelector";
-
-export const SidebarComponent = (): React.ReactNode => {
- const { user, currentWorkspace } = useAuth();
-
- return (
-
-
-
-
-
- Search
-
-
-
- Inbox
-
-
-
-
-
- Workspace
-
-
-
- Manage Workspace
-
-
-
- Invite Members
-
-
-
-
- {navItems.map((item) => (
-
-
- {item.label}
-
- ))}
-
- {/*
- New experiments
-
- Modifying voice of chatbot
-
-
- Different module order
-
-
- Asking feedback - 3 ways
-
-
- When to send the message
-
- */}
-
-
-
-
- Support
-
-
-
- Changelog
-
-
-
-
-
-
-
-
-
-
- {user}
-
-
- {user}
-
-
-
-
-
-
-
-
-
- );
-};
From b7ff0b507593353d2772f405cf5d5d43cd9c9466 Mon Sep 17 00:00:00 2001
From: Jay Prakash <0freerunning@gmail.com>
Date: Thu, 1 May 2025 13:58:07 +0300
Subject: [PATCH 09/74] Fix email link
---
backend/app/email.py | 2 +-
frontend/src/app/(protected)/workspaces/[workspaceId]/page.tsx | 1 +
2 files changed, 2 insertions(+), 1 deletion(-)
diff --git a/backend/app/email.py b/backend/app/email.py
index 9afc3ad..d7c7b82 100644
--- a/backend/app/email.py
+++ b/backend/app/email.py
@@ -170,7 +170,7 @@ async def send_workspace_invitation_email(
Hello,
You have been invited by {inviter_email} to join the workspace "{workspace_name}".
You need to create an account to join this workspace.
- Create Your Account
+ Create Your Account