From 3f0b6118898bb4282bf29a748415c53e294f8629 Mon Sep 17 00:00:00 2001 From: Jay Prakash <0freerunning@gmail.com> Date: Thu, 3 Apr 2025 03:49:02 +0530 Subject: [PATCH 01/74] Add user verification and password reset --- backend/add_users_to_db.py | 2 + backend/app/auth/dependencies.py | 52 +++- backend/app/auth/routers.py | 175 ++++++++++++- backend/app/auth/schemas.py | 2 + backend/app/auth/utils.py | 137 ++++++++++ backend/app/config.py | 7 + backend/app/contextual_mab/routers.py | 10 +- backend/app/database.py | 9 + backend/app/email.py | 168 ++++++++++++ backend/app/mab/routers.py | 10 +- backend/app/messages/routers.py | 10 +- backend/app/users/models.py | 61 +++++ backend/app/users/routers.py | 35 ++- backend/app/users/schemas.py | 45 +++- .../97137b6afb58_added_user_fields.py | 45 ++++ backend/requirements.txt | 2 + backend/tests/test.env | 8 + frontend/package-lock.json | 128 +++++++++ frontend/src/app/forgot-password/page.tsx | 163 ++++++++++++ frontend/src/app/login/page.tsx | 42 +-- frontend/src/app/reset-password/page.tsx | 245 ++++++++++++++++++ .../src/app/verification-required/page.tsx | 171 ++++++++++++ frontend/src/app/verify/page.tsx | 159 ++++++++++++ .../src/components/ProtectedComponent.tsx | 14 +- frontend/src/utils/api.ts | 43 +++ frontend/src/utils/auth.tsx | 114 +++++--- 26 files changed, 1773 insertions(+), 84 deletions(-) create mode 100644 backend/app/auth/utils.py create mode 100644 backend/app/email.py create mode 100644 backend/migrations/versions/97137b6afb58_added_user_fields.py create mode 100644 frontend/src/app/forgot-password/page.tsx create mode 100644 frontend/src/app/reset-password/page.tsx create mode 100644 frontend/src/app/verification-required/page.tsx create mode 100644 frontend/src/app/verify/page.tsx diff --git a/backend/add_users_to_db.py b/backend/add_users_to_db.py index 69e5f3f..0aa7183 100644 --- a/backend/add_users_to_db.py +++ b/backend/add_users_to_db.py @@ -36,6 +36,8 @@ api_daily_quota=ADMIN_API_DAILY_QUOTA, 
created_datetime_utc=datetime.now(timezone.utc), updated_datetime_utc=datetime.now(timezone.utc), + is_active=True, + is_verified=True, ) diff --git a/backend/app/auth/dependencies.py b/backend/app/auth/dependencies.py index c41aea0..1db3022 100644 --- a/backend/app/auth/dependencies.py +++ b/backend/app/auth/dependencies.py @@ -12,7 +12,7 @@ from jwt.exceptions import InvalidTokenError from sqlalchemy.ext.asyncio import AsyncSession -from ..config import CHECK_API_LIMIT, DEFAULT_API_QUOTA, DEFAULT_EXPERIMENTS_QUOTA +from ..config import CHECK_API_LIMIT, DEFAULT_API_QUOTA, DEFAULT_EXPERIMENTS_QUOTA, ENV from ..database import get_async_session from ..users.models import ( UserDB, @@ -20,6 +20,7 @@ get_user_by_api_key, get_user_by_username, save_user_to_db, + update_user_verification_status, ) from ..users.schemas import UserCreate from ..utils import ( @@ -54,6 +55,12 @@ async def authenticate_key( token = credentials.credentials try: user_db = await get_user_by_api_key(token, asession) + + if not user_db.is_active: + raise HTTPException( + status_code=403, detail="Account is inactive. Please contact support." + ) + return user_db except UserNotFoundError as e: raise HTTPException(status_code=403, detail="Invalid API key") from e @@ -67,12 +74,18 @@ async def authenticate_credentials( """ try: user_db = await get_user_by_username(username, asession) + + if not user_db.is_active: + logger.warning(f"Inactive user {username} attempted to login") + return None + if verify_password_salted_hash(password, user_db.hashed_password): # hardcode "fullaccess" now, but may use it in the future return AuthenticatedUser( username=username, access_level="fullaccess", api_key_first_characters=user_db.api_key_first_characters, + is_verified=user_db.is_verified, ) else: return None @@ -87,14 +100,21 @@ async def authenticate_or_create_google_user( asession: AsyncSession, ) -> Optional[AuthenticatedUser]: """ - Check if user exists in Db. 
If not, create user + Check if user exists in Db. If not, create user. + Google authenticated users are automatically verified. """ try: user_db = await get_user_by_username(google_email, asession) + + if not user_db.is_verified: + asession.add(user_db) + await update_user_verification_status(user_db, True, asession) + return AuthenticatedUser( username=user_db.username, access_level="fullaccess", api_key_first_characters=user_db.api_key_first_characters, + is_verified=user_db.is_verified, ) except UserNotFoundError: user = UserCreate( @@ -103,7 +123,7 @@ async def authenticate_or_create_google_user( api_daily_quota=DEFAULT_API_QUOTA, ) api_key = generate_key() - user_db = await save_user_to_db(user, api_key, asession) + user_db = await save_user_to_db(user, api_key, asession, is_verified=True) await update_api_limits( request.app.state.redis, user_db.username, user_db.api_daily_quota ) @@ -111,6 +131,7 @@ async def authenticate_or_create_google_user( username=user_db.username, access_level="fullaccess", api_key_first_characters=user_db.api_key_first_characters, + is_verified=True, ) @@ -135,6 +156,13 @@ async def get_current_user( # fetch user from database try: user_db = await get_user_by_username(username, asession) + + if not user_db.is_active: + raise HTTPException( + status_code=403, + detail="Account is inactive. Please contact support.", + ) + return user_db except UserNotFoundError as err: raise credentials_exception from err @@ -142,6 +170,24 @@ async def get_current_user( raise credentials_exception from err +async def get_verified_user( + user_db: Annotated[UserDB, Depends(get_current_user)], +) -> UserDB: + """ + Check if the user is verified + """ + if ENV == "testing": + return user_db + + if not user_db.is_verified: + raise HTTPException( + status_code=403, + detail="Account not verified. 
Please check your email to verify " + "your account.", + ) + return user_db + + def create_access_token(username: str) -> str: """ Create an access token for the user diff --git a/backend/app/auth/routers.py b/backend/app/auth/routers.py index 6fc22d6..eec0c4a 100644 --- a/backend/app/auth/routers.py +++ b/backend/app/auth/routers.py @@ -1,11 +1,26 @@ -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException from fastapi.requests import Request from fastapi.security import OAuth2PasswordRequestForm -from google.auth.transport import requests +from google.auth.transport import requests as google_requests from google.oauth2 import id_token +from redis.asyncio import Redis from sqlalchemy.ext.asyncio import AsyncSession -from ..database import get_async_session +from ..database import get_async_session, get_redis +from ..email import EmailService +from ..users.models import ( + UserNotFoundError, + get_user_by_username, + update_user_password, + update_user_verification_status, +) +from ..users.schemas import ( + EmailVerificationRequest, + MessageResponse, + PasswordResetConfirm, + PasswordResetRequest, +) +from ..utils import setup_logger from .config import NEXT_PUBLIC_GOOGLE_LOGIN_CLIENT_ID from .dependencies import ( authenticate_credentials, @@ -13,6 +28,11 @@ create_access_token, ) from .schemas import AuthenticationDetails, GoogleLoginData +from .utils import ( + generate_password_reset_token, + generate_verification_token, + verify_token, +) TAG_METADATA = { "name": "Authentication", @@ -21,6 +41,9 @@ router = APIRouter(tags=[TAG_METADATA["name"]]) +email_service = EmailService() +logger = setup_logger() + @router.post("/login") async def login( @@ -47,6 +70,7 @@ async def login( token_type="bearer", access_level=user.access_level, username=user.username, + is_verified=user.is_verified, ) @@ -58,13 +82,13 @@ async def login_google( ) -> AuthenticationDetails: """ Verify google token, check 
if user exists. If user does not exist, create user - Return JWT token for user + Return JWT token for user. Google users are automatically verified. """ try: idinfo = id_token.verify_oauth2_token( login_data.credential, - requests.Request(), + google_requests.Request(), NEXT_PUBLIC_GOOGLE_LOGIN_CLIENT_ID, ) if idinfo["iss"] not in ["accounts.google.com", "https://accounts.google.com"]: @@ -87,4 +111,145 @@ async def login_google( token_type="bearer", access_level=user.access_level, username=user.username, + is_verified=user.is_verified, ) + + +@router.post("/request-password-reset", response_model=MessageResponse) +async def request_password_reset( + reset_request: PasswordResetRequest, + background_tasks: BackgroundTasks, + asession: AsyncSession = Depends(get_async_session), + redis: Redis = Depends(get_redis), +) -> MessageResponse: + """ + Request a password reset email + """ + response_msg = ( + "If an account with this email exists, a password reset link has been sent." + ) + + try: + logger.info(f"Generated password reset token for user {reset_request.username}") + user = await get_user_by_username(reset_request.username, asession) + + logger.info(f"User found: {user.username}") + if not user: + return MessageResponse(message=response_msg) + + token = await generate_password_reset_token(user.user_id, user.username, redis) + background_tasks.add_task( + email_service.send_password_reset_email, user.username, user.username, token + ) + + return MessageResponse(message=response_msg) + except UserNotFoundError: + logger.warning(f"User not found: {reset_request.username}") + return MessageResponse(message=response_msg) + except Exception as e: + logger.exception("An error occurred processing your request") + raise HTTPException( + status_code=500, detail="An error occurred processing your request" + ) from e + + +@router.post("/reset-password", response_model=MessageResponse) +async def reset_password( + reset_data: PasswordResetConfirm, + asession: AsyncSession 
= Depends(get_async_session), + redis: Redis = Depends(get_redis), +) -> MessageResponse: + """ + Reset a user's password with the provided token + """ + is_valid, payload = await verify_token(reset_data.token, "password_reset", redis) + + if not is_valid or "username" not in payload: + raise HTTPException( + status_code=400, detail="Invalid or expired password reset token" + ) + + try: + user = await get_user_by_username(payload["username"], asession) + asession.add(user) + + await update_user_password(user, reset_data.new_password, asession) + + return MessageResponse(message="Your password has been reset successfully") + except UserNotFoundError as e: + raise HTTPException(status_code=404, detail="User not found") from e + except Exception as e: + raise HTTPException( + status_code=500, detail="An error occurred while resetting your password" + ) from e + + +@router.post("/verify-email", response_model=MessageResponse) +async def verify_email( + verification_data: EmailVerificationRequest, + asession: AsyncSession = Depends(get_async_session), + redis: Redis = Depends(get_redis), +) -> MessageResponse: + """ + Verify a user's email with the provided token + """ + is_valid, payload = await verify_token( + verification_data.token, "verification", redis + ) + + if not is_valid or "username" not in payload: + raise HTTPException( + status_code=400, detail="Invalid or expired verification token" + ) + + try: + user = await get_user_by_username(payload["username"], asession) + + asession.add(user) + await update_user_verification_status(user, True, asession) + + return MessageResponse(message="Your email has been verified successfully") + except UserNotFoundError as e: + raise HTTPException(status_code=404, detail="User not found") from e + except Exception as e: + raise HTTPException( + status_code=500, detail="An error occurred while verifying your email" + ) from e + + +@router.post("/resend-verification", response_model=MessageResponse) +async def 
resend_verification( + reset_request: PasswordResetRequest, + background_tasks: BackgroundTasks, + asession: AsyncSession = Depends(get_async_session), + redis: Redis = Depends(get_redis), +) -> MessageResponse: + """ + Resend verification email + """ + response_msg = ( + "If an account with this email exists, a password reset link has been sent." + ) + + try: + user = await get_user_by_username(reset_request.username, asession) + + if not user: + return MessageResponse(message=response_msg) + + if user.is_verified: + return MessageResponse(message="Your account is already verified") + + token = await generate_verification_token(user.user_id, user.username, redis) + + background_tasks.add_task( + email_service.send_verification_email, user.username, user.username, token + ) + + return MessageResponse(message=response_msg) + except UserNotFoundError: + return MessageResponse(message=response_msg) + except Exception as e: + raise HTTPException( + status_code=500, detail="An error occurred processing your request" + ) from e diff --git a/backend/app/auth/schemas.py b/backend/app/auth/schemas.py index c8da6d7..9b4b853 100644 --- a/backend/app/auth/schemas.py +++ b/backend/app/auth/schemas.py @@ -14,6 +14,7 @@ class AuthenticatedUser(BaseModel): username: str access_level: AccessLevel api_key_first_characters: str + is_verified: bool model_config = ConfigDict(from_attributes=True) @@ -38,5 +39,6 @@ class AuthenticationDetails(BaseModel): access_level: AccessLevel api_key_first_characters: str username: str + is_verified: bool model_config = ConfigDict(from_attributes=True) diff --git a/backend/app/auth/utils.py b/backend/app/auth/utils.py new file mode 100644 index 0000000..d6b2862 --- /dev/null +++ b/backend/app/auth/utils.py @@ -0,0 +1,137 @@ +import secrets +from datetime import datetime, timedelta, timezone +from typing import Dict, Tuple + +import jwt +from redis.asyncio import Redis + +from ..utils import setup_logger +from .config import JWT_ALGORITHM, JWT_SECRET 
+ +# Token settings +VERIFICATION_TOKEN_EXPIRE_MINUTES = 1440 # 24 hours +PASSWORD_RESET_TOKEN_EXPIRE_MINUTES = 30 # 30 minutes + +logger = setup_logger() + + +async def generate_verification_token(user_id: int, username: str, redis: Redis) -> str: + """ + Generates a verification token for account activation + + Args: + user_id: The user's ID + username: The user's username (email) + redis: Redis connection + + Returns: + JWT token for email verification + """ + # Generate JWT token + token_jti = secrets.token_hex(16) # Add unique ID to prevent token reuse + payload = { + "sub": str(user_id), + "username": username, + "type": "verification", + "exp": datetime.now(timezone.utc) + + timedelta(minutes=VERIFICATION_TOKEN_EXPIRE_MINUTES), + "iat": datetime.now(timezone.utc), + "jti": token_jti, + } + + token = jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM) + + # Store token in Redis with expiry for additional security + await redis.set( + f"verification_token:{token_jti}", + str(user_id), + ex=VERIFICATION_TOKEN_EXPIRE_MINUTES * 60, + ) + + logger.info(f"Generated verification token for user {user_id}") + return token + + +async def generate_password_reset_token( + user_id: int, username: str, redis: Redis +) -> str: + """ + Generates a token for password reset + + Args: + user_id: The user's ID + username: The user's username (email) + redis: Redis connection + + Returns: + JWT token for password reset + """ + # Generate JWT token + token_jti = secrets.token_hex(16) # Add unique ID to prevent token reuse + payload = { + "sub": str(user_id), + "username": username, + "type": "password_reset", + "exp": datetime.now(timezone.utc) + + timedelta(minutes=PASSWORD_RESET_TOKEN_EXPIRE_MINUTES), + "iat": datetime.now(timezone.utc), + "jti": token_jti, + } + + token = jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM) + + # Store token in Redis with expiry for additional security + await redis.set( + f"password_reset_token:{token_jti}", + str(user_id), + 
ex=PASSWORD_RESET_TOKEN_EXPIRE_MINUTES * 60, + ) + + logger.info(f"Generated password reset token for user {user_id}") + return token + + +async def verify_token(token: str, token_type: str, redis: Redis) -> Tuple[bool, Dict]: + """ + Verifies a token and returns user information if valid + + Args: + token: The JWT token + token_type: Either "verification" or "password_reset" + redis: Redis connection + + Returns: + Tuple of (is_valid, payload) + """ + try: + # Decode the token + payload = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM]) + + # Check token type + if payload.get("type") != token_type: + logger.warning( + f"Invalid token type: expected {token_type}, got {payload.get('type')}" + ) + return False, {} + + # Check if token is in Redis (hasn't been used) + token_key = f"{token_type}_token:{payload['jti']}" + stored_user_id = await redis.get(token_key) + + if not stored_user_id or stored_user_id.decode() != payload["sub"]: + logger.warning( + "Token validation failed: token not found in Redis or user_id mismatch" + ) + return False, {} + + # If verification successful, invalidate token to prevent reuse + await redis.delete(token_key) + + logger.info( + f"Successfully verified {token_type} token for user {payload['sub']}" + ) + return True, payload + + except jwt.PyJWTError as e: + logger.error(f"JWT token verification error: {str(e)}") + return False, {} diff --git a/backend/app/config.py b/backend/app/config.py index 89de874..55ec33e 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -1,5 +1,6 @@ import os +ENV = os.environ.get("ENV", "development") POSTGRES_USER = os.environ.get("POSTGRES_USER", "postgres") POSTGRES_PASSWORD = os.environ.get("POSTGRES_PASSWORD", "postgres") POSTGRES_HOST = os.environ.get("POSTGRES_HOST", "localhost") @@ -17,3 +18,9 @@ DEFAULT_API_QUOTA = int(os.environ.get("DEFAULT_API_QUOTA", 100)) CHECK_API_LIMIT = os.environ.get("CHECK_API_LIMIT", True) CHECK_EXPERIMENTS_LIMIT = 
os.environ.get("CHECK_EXPERIMENTS_LIMIT", True) + +SES_REGION = os.environ.get("SES_REGION", None) +SES_SENDER_EMAIL = os.environ.get("SES_SENDER_EMAIL", None) +FRONTEND_URL = os.environ.get("FRONTEND_URL", None) +AWS_ACCESS_KEY_ID = os.environ.get("AWS_ACCESS_KEY_ID", None) +AWS_SECRET_ACCESS_KEY = os.environ.get("AWS_SECRET_ACCESS_KEY", None) diff --git a/backend/app/contextual_mab/routers.py b/backend/app/contextual_mab/routers.py index dd56b94..95a6524 100644 --- a/backend/app/contextual_mab/routers.py +++ b/backend/app/contextual_mab/routers.py @@ -4,7 +4,7 @@ from fastapi.exceptions import HTTPException from sqlalchemy.ext.asyncio import AsyncSession -from ..auth.dependencies import authenticate_key, get_current_user +from ..auth.dependencies import authenticate_key, get_verified_user from ..database import get_async_session from ..models import get_notifications_from_db, save_notifications_to_db from ..schemas import ContextType, NotificationsResponse, Outcome @@ -35,7 +35,7 @@ @router.post("/", response_model=ContextualBanditResponse) async def create_contextual_mabs( experiment: ContextualBandit, - user_db: Annotated[UserDB, Depends(get_current_user)], + user_db: Annotated[UserDB, Depends(get_verified_user)], asession: AsyncSession = Depends(get_async_session), ) -> ContextualBanditResponse | HTTPException: """ @@ -55,7 +55,7 @@ async def create_contextual_mabs( @router.get("/", response_model=list[ContextualBanditResponse]) async def get_contextual_mabs( - user_db: Annotated[UserDB, Depends(get_current_user)], + user_db: Annotated[UserDB, Depends(get_verified_user)], asession: AsyncSession = Depends(get_async_session), ) -> list[ContextualBanditResponse]: """ @@ -89,7 +89,7 @@ async def get_contextual_mabs( @router.get("/{experiment_id}", response_model=ContextualBanditResponse) async def get_contextual_mab( experiment_id: int, - user_db: Annotated[UserDB, Depends(get_current_user)], + user_db: Annotated[UserDB, Depends(get_verified_user)], asession: 
AsyncSession = Depends(get_async_session), ) -> ContextualBanditResponse | HTTPException: """ @@ -116,7 +116,7 @@ async def get_contextual_mab( @router.delete("/{experiment_id}", response_model=dict) async def delete_contextual_mab( experiment_id: int, - user_db: Annotated[UserDB, Depends(get_current_user)], + user_db: Annotated[UserDB, Depends(get_verified_user)], asession: AsyncSession = Depends(get_async_session), ) -> dict: """ diff --git a/backend/app/database.py b/backend/app/database.py index e944dd2..7265e9e 100644 --- a/backend/app/database.py +++ b/backend/app/database.py @@ -2,6 +2,8 @@ from collections.abc import AsyncGenerator, Generator from typing import ContextManager +from fastapi import Request +from redis.asyncio import Redis from sqlalchemy.engine import URL, Engine, create_engine from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine from sqlalchemy.orm import Session @@ -81,3 +83,10 @@ async def get_async_session() -> AsyncGenerator[AsyncSession, None]: get_sqlalchemy_async_engine(), expire_on_commit=False ) as async_session: yield async_session + + +async def get_redis(request: Request) -> Redis: + """ + Returns the Redis client for the request. 
+ """ + return request.app.state.redis diff --git a/backend/app/email.py b/backend/app/email.py new file mode 100644 index 0000000..be59269 --- /dev/null +++ b/backend/app/email.py @@ -0,0 +1,168 @@ +from typing import Any, Dict, Optional + +import boto3 +from botocore.exceptions import ClientError +from fastapi import HTTPException + +from .config import ( + AWS_ACCESS_KEY_ID, + AWS_SECRET_ACCESS_KEY, + ENV, + FRONTEND_URL, + SES_REGION, + SES_SENDER_EMAIL, +) +from .utils import setup_logger + +logger = setup_logger() + + +class EmailService: + """Service to send emails via AWS SES""" + + def __init__( + self, + aws_region: Optional[str] = None, + sender_email: Optional[str] = None, + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + ) -> None: + """Initialize the email service with AWS credentials""" + self.aws_region = aws_region or SES_REGION + self.sender_email = sender_email or SES_SENDER_EMAIL + + session_kwargs: Dict[str, Any] = {"region_name": self.aws_region} + + aws_access_key = aws_access_key_id or AWS_ACCESS_KEY_ID + aws_secret_key = aws_secret_access_key or AWS_SECRET_ACCESS_KEY + + if aws_access_key and aws_secret_key: + session_kwargs.update( + { + "aws_access_key_id": aws_access_key, + "aws_secret_access_key": aws_secret_key, + } + ) + + self.client = boto3.client("ses", **session_kwargs) + + async def send_verification_email( + self, email: str, username: str, token: str + ) -> Dict[str, Any]: + """ + Send account verification email + """ + verification_url = f"{FRONTEND_URL}/verify?token={token}" + + subject = "Verify Your Account" + html_body = f""" + + + +

Account Verification

+

Hello {username},

+

Thank you for signing up. Please click the link below to verify + your account:

+

Verify My Account

+

This link will expire in 24 hours.

+

If you did not create this account, please ignore this email.

+ + + """ + text_body = f""" + Account Verification + + Hello {username}, + + Thank you for signing up. Please use the link below to verify your account: + + {verification_url} + + This link will expire in 24 hours. + + If you did not create this account, please ignore this email. + """ + + return await self._send_email(email, subject, html_body, text_body) + + async def send_password_reset_email( + self, email: str, username: str, token: str + ) -> Dict[str, Any]: + """ + Send password reset email + """ + reset_url = f"{FRONTEND_URL}/reset-password?token={token}" + + subject = "Password Reset Request" + html_body = f""" + + + +

Password Reset

+

Hello {username},

+

We received a request to reset your password. Please click the link below + to set a new password:

+

Reset My Password

+

This link will expire in 30 minutes.

+

If you did not request a password reset, please ignore this email.

+ + + """ + text_body = f""" + Password Reset + + Hello {username}, + + We received a request to reset your password. Please use the link below + to set a new password: + + {reset_url} + + This link will expire in 30 minutes. + + If you did not request a password reset, please ignore this email. + """ + + return await self._send_email(email, subject, html_body, text_body) + + async def _send_email( + self, recipient: str, subject: str, html_body: str, text_body: str + ) -> Dict[str, Any]: + """ + Send an email using AWS SES + """ + if ENV == "testing" or self.client is None: + logger.info(f"[MOCK EMAIL] To: {recipient}, Subject: {subject}") + logger.info(f"[MOCK EMAIL] Text: {text_body[:100]}...") + return {"MessageId": "mock-message-id"} + + try: + response = self.client.send_email( + Source=self.sender_email, + Destination={ + "ToAddresses": [recipient], + }, + Message={ + "Subject": { + "Data": subject, + }, + "Body": { + "Text": { + "Data": text_body, + }, + "Html": { + "Data": html_body, + }, + }, + }, + ) + logger.info( + f"Email sent to {recipient}! 
Message ID: {response.get('MessageId')}" + ) + return response + except ClientError as e: + error_message = f"Error sending email to {recipient}: {e}" + logger.error(error_message) + raise HTTPException( + status_code=500, detail=f"Failed to send email: {str(e)}" + ) from e diff --git a/backend/app/mab/routers.py b/backend/app/mab/routers.py index 75b2cc0..86cd179 100644 --- a/backend/app/mab/routers.py +++ b/backend/app/mab/routers.py @@ -4,7 +4,7 @@ from fastapi.exceptions import HTTPException from sqlalchemy.ext.asyncio import AsyncSession -from ..auth.dependencies import authenticate_key, get_current_user +from ..auth.dependencies import authenticate_key, get_verified_user from ..database import get_async_session from ..models import get_notifications_from_db, save_notifications_to_db from ..schemas import NotificationsResponse, Outcome, RewardLikelihood @@ -33,7 +33,7 @@ @router.post("/", response_model=MultiArmedBanditResponse) async def create_mab( experiment: MultiArmedBandit, - user_db: Annotated[UserDB, Depends(get_current_user)], + user_db: Annotated[UserDB, Depends(get_verified_user)], asession: AsyncSession = Depends(get_async_session), ) -> MultiArmedBanditResponse: """ @@ -55,7 +55,7 @@ async def create_mab( @router.get("/", response_model=list[MultiArmedBanditResponse]) async def get_mabs( - user_db: Annotated[UserDB, Depends(get_current_user)], + user_db: Annotated[UserDB, Depends(get_verified_user)], asession: AsyncSession = Depends(get_async_session), ) -> list[MultiArmedBanditResponse]: """ @@ -88,7 +88,7 @@ async def get_mabs( @router.get("/{experiment_id}", response_model=MultiArmedBanditResponse) async def get_mab( experiment_id: int, - user_db: Annotated[UserDB, Depends(get_current_user)], + user_db: Annotated[UserDB, Depends(get_verified_user)], asession: AsyncSession = Depends(get_async_session), ) -> MultiArmedBanditResponse: """ @@ -115,7 +115,7 @@ async def get_mab( @router.delete("/{experiment_id}", response_model=dict) async def 
delete_mab( experiment_id: int, - user_db: Annotated[UserDB, Depends(get_current_user)], + user_db: Annotated[UserDB, Depends(get_verified_user)], asession: AsyncSession = Depends(get_async_session), ) -> dict: """ diff --git a/backend/app/messages/routers.py b/backend/app/messages/routers.py index 31eff9c..fea7bbe 100644 --- a/backend/app/messages/routers.py +++ b/backend/app/messages/routers.py @@ -3,7 +3,7 @@ from fastapi import APIRouter, Depends from sqlalchemy.ext.asyncio import AsyncSession -from ..auth.dependencies import get_current_user +from ..auth.dependencies import get_verified_user from ..database import get_async_session from ..users.models import UserDB from .models import EventMessageDB, MessageDB @@ -14,7 +14,7 @@ @router.get("/", response_model=list[MessageResponse]) async def get_messages( - user_db: Annotated[UserDB, Depends(get_current_user)], + user_db: Annotated[UserDB, Depends(get_verified_user)], asession: AsyncSession = Depends(get_async_session), ) -> list[MessageResponse]: """ @@ -27,7 +27,7 @@ async def get_messages( @router.post("/", response_model=MessageResponse) async def create_message( message: EventMessageCreate, - user_db: Annotated[UserDB, Depends(get_current_user)], + user_db: Annotated[UserDB, Depends(get_verified_user)], asession: AsyncSession = Depends(get_async_session), ) -> MessageResponse: """ @@ -47,7 +47,7 @@ async def create_message( @router.delete("/", response_model=list[MessageResponse]) async def delete_messages( message_ids: list[int], - user_db: Annotated[UserDB, Depends(get_current_user)], + user_db: Annotated[UserDB, Depends(get_verified_user)], asession: AsyncSession = Depends(get_async_session), ) -> list[MessageResponse]: """ @@ -63,7 +63,7 @@ async def delete_messages( @router.patch("/", response_model=list[MessageResponse]) async def mark_messages_as_read( message_read_toggle: MessageReadToggle, - user_db: Annotated[UserDB, Depends(get_current_user)], + user_db: Annotated[UserDB, 
Depends(get_verified_user)], asession: AsyncSession = Depends(get_async_session), ) -> list[MessageResponse]: """ diff --git a/backend/app/users/models.py b/backend/app/users/models.py index 5f1b6d1..6500c3a 100644 --- a/backend/app/users/models.py +++ b/backend/app/users/models.py @@ -1,6 +1,7 @@ from datetime import datetime, timezone from sqlalchemy import ( + Boolean, DateTime, Integer, String, @@ -48,6 +49,11 @@ class UserDB(Base): updated_datetime_utc: Mapped[datetime] = mapped_column( DateTime(timezone=True), nullable=False ) + access_level: Mapped[str] = mapped_column( + String, nullable=False, default="fullaccess" + ) + is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) + is_verified: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) def __repr__(self) -> str: """Pretty Print""" @@ -58,6 +64,7 @@ async def save_user_to_db( user: UserCreateWithPassword | UserCreate, api_key: str, asession: AsyncSession, + is_verified: bool = False, ) -> UserDB: """ Saves a user in the database @@ -90,6 +97,9 @@ async def save_user_to_db( api_key_first_characters=api_key[:5], created_datetime_utc=datetime.now(timezone.utc), updated_datetime_utc=datetime.now(timezone.utc), + is_active=True, + is_verified=is_verified, + access_level="fullaccess", ) asession.add(user_db) @@ -119,6 +129,57 @@ async def update_user_api_key( return user_db +async def update_user_verification_status( + user_db: UserDB, + is_verified: bool, + asession: AsyncSession, +) -> UserDB: + """ + Updates a user's verification status + """ + user_db.is_verified = is_verified + user_db.updated_datetime_utc = datetime.now(timezone.utc) + + await asession.commit() + await asession.refresh(user_db) + + return user_db + + +async def update_user_active_status( + user_db: UserDB, + is_active: bool, + asession: AsyncSession, +) -> UserDB: + """ + Updates a user's active status + """ + user_db.is_active = is_active + user_db.updated_datetime_utc = 
datetime.now(timezone.utc) + + await asession.commit() + await asession.refresh(user_db) + + return user_db + + +async def update_user_password( + user_db: UserDB, + new_password: str, + asession: AsyncSession, +) -> UserDB: + """ + Updates a user's password + """ + user_db.hashed_password = get_password_salted_hash(new_password) + user_db.updated_datetime_utc = datetime.now(timezone.utc) + + await asession.commit() + await asession.refresh(user_db) + + return user_db + + async def get_user_by_username( username: str, asession: AsyncSession, diff --git a/backend/app/users/routers.py b/backend/app/users/routers.py index 435109a..4bf80dd 100644 --- a/backend/app/users/routers.py +++ b/backend/app/users/routers.py @@ -1,13 +1,15 @@ from typing import Annotated -from fastapi import APIRouter, Depends -from fastapi.exceptions import HTTPException +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException from fastapi.requests import Request +from redis.asyncio import Redis from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession -from ..auth.dependencies import get_current_user -from ..database import get_async_session +from ..auth.dependencies import get_current_user, get_verified_user +from ..auth.utils import generate_verification_token +from ..database import get_async_session, get_redis +from ..email import EmailService from ..users.models import ( UserAlreadyExistsError, UserDB, @@ -17,21 +19,25 @@ from ..utils import generate_key, setup_logger, update_api_limits from .schemas import KeyResponse, UserCreate, UserCreateWithPassword, UserRetrieve +# Router setup TAG_METADATA = { "name": "Admin", "description": "_Requires user login._ Only administrator user has access to these " "endpoints.", } -router = APIRouter(prefix="/user", tags=["Admin"]) +router = APIRouter(prefix="/user", tags=["Users"]) logger = setup_logger() +email_service = EmailService() @router.post("/", response_model=UserCreate) async def create_user( 
user: UserCreateWithPassword, request: Request, + background_tasks: BackgroundTasks, asession: AsyncSession = Depends(get_async_session), + redis: Redis = Depends(get_redis), ) -> UserCreate | None: """ Create user endpoint. @@ -43,9 +49,19 @@ async def create_user( user=user, api_key=new_api_key, asession=asession, + is_verified=False, ) - await update_api_limits( - request.app.state.redis, user_new.username, user_new.api_daily_quota + await update_api_limits(redis, user_new.username, user_new.api_daily_quota) + + token = await generate_verification_token( + user_new.user_id, user_new.username, redis + ) + + background_tasks.add_task( + email_service.send_verification_email, + user_new.username, + user_new.username, + token, ) return UserCreate( @@ -76,12 +92,15 @@ async def get_user( api_key_updated_datetime_utc=user_db.api_key_updated_datetime_utc, created_datetime_utc=user_db.created_datetime_utc, updated_datetime_utc=user_db.updated_datetime_utc, + is_active=user_db.is_active, + is_verified=user_db.is_verified, + access_level=user_db.access_level, ) @router.put("/rotate-key", response_model=KeyResponse) async def get_new_api_key( - user_db: Annotated[UserDB, Depends(get_current_user)], + user_db: Annotated[UserDB, Depends(get_verified_user)], asession: AsyncSession = Depends(get_async_session), ) -> KeyResponse | None: """ diff --git a/backend/app/users/schemas.py b/backend/app/users/schemas.py index d773de3..5bb3294 100644 --- a/backend/app/users/schemas.py +++ b/backend/app/users/schemas.py @@ -1,10 +1,9 @@ from datetime import datetime from typing import Optional -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, EmailStr -# not yet used. class UserCreate(BaseModel): """ Pydantic model for user creation @@ -19,7 +18,7 @@ class UserCreate(BaseModel): class UserCreateWithPassword(UserCreate): """ - Pydantic model for user creation + Pydantic model for user creation with password. 
""" password: str @@ -39,6 +38,9 @@ class UserRetrieve(BaseModel): api_key_updated_datetime_utc: datetime created_datetime_utc: datetime updated_datetime_utc: datetime + is_active: bool + is_verified: bool + access_level: str model_config = ConfigDict(from_attributes=True) @@ -51,3 +53,40 @@ class KeyResponse(BaseModel): username: str new_api_key: str model_config = ConfigDict(from_attributes=True) + + +class PasswordResetRequest(BaseModel): + """ + Pydantic model for password reset request + """ + + username: EmailStr + model_config = ConfigDict(from_attributes=True) + + +class PasswordResetConfirm(BaseModel): + """ + Pydantic model for password reset confirmation + """ + + token: str + new_password: str + model_config = ConfigDict(from_attributes=True) + + +class EmailVerificationRequest(BaseModel): + """ + Pydantic model for email verification + """ + + token: str + model_config = ConfigDict(from_attributes=True) + + +class MessageResponse(BaseModel): + """ + Pydantic model for generic message responses + """ + + message: str + model_config = ConfigDict(from_attributes=True) diff --git a/backend/migrations/versions/97137b6afb58_added_user_fields.py b/backend/migrations/versions/97137b6afb58_added_user_fields.py new file mode 100644 index 0000000..5743a04 --- /dev/null +++ b/backend/migrations/versions/97137b6afb58_added_user_fields.py @@ -0,0 +1,45 @@ +"""Added user fields + +Revision ID: 97137b6afb58 +Revises: d95b5c0590c3 +Create Date: 2025-04-02 21:09:56.417771 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "97137b6afb58" +down_revision: Union[str, None] = "d95b5c0590c3" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.add_column( + "users", + sa.Column( + "access_level", sa.String(), nullable=False, server_default="fullaccess" + ), + ) + op.add_column( + "users", + sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"), + ) + op.add_column( + "users", + sa.Column("is_verified", sa.Boolean(), nullable=False, server_default="false"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("users", "is_verified") + op.drop_column("users", "is_active") + op.drop_column("users", "access_level") + # ### end Alembic commands ### diff --git a/backend/requirements.txt b/backend/requirements.txt index 27b026d..18633d3 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -13,3 +13,5 @@ scikit-learn==1.6.1 scipy==1.15.2 sqlalchemy[asyncio]==2.0.20 uvicorn==0.23.2 +boto3==1.37.25 +pydantic[email] diff --git a/backend/tests/test.env b/backend/tests/test.env index 648493b..a222901 100644 --- a/backend/tests/test.env +++ b/backend/tests/test.env @@ -1,3 +1,4 @@ +ENV=testing PROMETHEUS_MULTIPROC_DIR=/tmp # DB connection POSTGRES_USER=postgres-test-user @@ -10,3 +11,10 @@ REDIS_HOST=redis://localhost:6381 ADMIN_USERNAME=test@idinsight.org ADMIN_PASSWORD=test123 ADMIN_API_KEY=testkey123 + +NEXT_PUBLIC_GOOGLE_LOGIN_CLIENT_ID=random_key +SES_REGION=ap-south-1 +SES_SENDER_EMAIL=no-reply@example.com +FRONTEND_URL=http://localhost:3000 +AWS_ACCESS_KEY_ID=aws_access_key +AWS_SECRET_ACCESS_KEY=aws_secret_key diff --git a/frontend/package-lock.json b/frontend/package-lock.json index c0e85ec..2d97cb8 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -408,6 +408,134 @@ "node": ">= 10" } }, + "node_modules/@next/swc-darwin-x64": { + "version": "14.2.25", + "resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-14.2.25.tgz", + "integrity": 
"sha512-V+iYM/QR+aYeJl3/FWWU/7Ix4b07ovsQ5IbkwgUK29pTHmq+5UxeDr7/dphvtXEq5pLB/PucfcBNh9KZ8vWbug==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@next/swc-linux-arm64-gnu": { + "version": "14.2.25", + "resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-14.2.25.tgz", + "integrity": "sha512-LFnV2899PJZAIEHQ4IMmZIgL0FBieh5keMnriMY1cK7ompR+JUd24xeTtKkcaw8QmxmEdhoE5Mu9dPSuDBgtTg==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@next/swc-linux-arm64-musl": { + "version": "14.2.25", + "resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-14.2.25.tgz", + "integrity": "sha512-QC5y5PPTmtqFExcKWKYgUNkHeHE/z3lUsu83di488nyP0ZzQ3Yse2G6TCxz6nNsQwgAx1BehAJTZez+UQxzLfw==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@next/swc-linux-x64-gnu": { + "version": "14.2.25", + "resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-14.2.25.tgz", + "integrity": "sha512-y6/ML4b9eQ2D/56wqatTJN5/JR8/xdObU2Fb1RBidnrr450HLCKr6IJZbPqbv7NXmje61UyxjF5kvSajvjye5w==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@next/swc-linux-x64-musl": { + "version": "14.2.25", + "resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-14.2.25.tgz", + "integrity": "sha512-sPX0TSXHGUOZFvv96GoBXpB3w4emMqKeMgemrSxI7A6l55VBJp/RKYLwZIB9JxSqYPApqiREaIIap+wWq0RU8w==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@next/swc-win32-arm64-msvc": { + "version": "14.2.25", + "resolved": 
"https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-14.2.25.tgz", + "integrity": "sha512-ReO9S5hkA1DU2cFCsGoOEp7WJkhFzNbU/3VUF6XxNGUCQChyug6hZdYL/istQgfT/GWE6PNIg9cm784OI4ddxQ==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@next/swc-win32-ia32-msvc": { + "version": "14.2.25", + "resolved": "https://registry.npmjs.org/@next/swc-win32-ia32-msvc/-/swc-win32-ia32-msvc-14.2.25.tgz", + "integrity": "sha512-DZ/gc0o9neuCDyD5IumyTGHVun2dCox5TfPQI/BJTYwpSNYM3CZDI4i6TOdjeq1JMo+Ug4kPSMuZdwsycwFbAw==", + "cpu": [ + "ia32" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@next/swc-win32-x64-msvc": { + "version": "14.2.25", + "resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-14.2.25.tgz", + "integrity": "sha512-KSznmS6eFjQ9RJ1nEc66kJvtGIL1iZMYmGEXsZPh2YtnLtqrgdVvKXJY2ScjjoFnG6nGLyPFR0UiEvDwVah4Tw==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">= 10" + } + }, "node_modules/@nodelib/fs.scandir": { "version": "2.1.5", "resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz", diff --git a/frontend/src/app/forgot-password/page.tsx b/frontend/src/app/forgot-password/page.tsx new file mode 100644 index 0000000..3f048cf --- /dev/null +++ b/frontend/src/app/forgot-password/page.tsx @@ -0,0 +1,163 @@ +"use client"; + +import { motion } from "framer-motion"; +import Link from "next/link"; +import { zodResolver } from "@hookform/resolvers/zod"; +import { useForm } from "react-hook-form"; +import * as z from "zod"; +import { Button } from "@/components/ui/button"; +import { + Card, + CardContent, + CardDescription, + CardFooter, + CardHeader, + CardTitle, +} from "@/components/ui/card"; +import { + Form, + FormControl, + FormField, + FormItem, + 
FormLabel, + FormMessage, +} from "@/components/ui/form"; +import { Input } from "@/components/ui/input"; +import { useState } from "react"; +import { apiCalls } from "@/utils/api"; + +const formSchema = z.object({ + email: z.string().email({ + message: "Please enter a valid email address.", + }), +}); + +export default function ForgotPasswordPage() { + const form = useForm>({ + resolver: zodResolver(formSchema), + defaultValues: { + email: "", + }, + }); + + const [isSubmitting, setIsSubmitting] = useState(false); + const [success, setSuccess] = useState(false); + const [errorState, setErrorState] = useState(null); + + async function onSubmit(values: z.infer) { + setIsSubmitting(true); + setErrorState(null); + + try { + const response = await apiCalls.requestPasswordReset(values.email); + console.log("Password reset request response:", response); + setSuccess(true); + } catch (error) { + setErrorState("An error occurred while sending the reset link. Please try again later."); + console.error("Password reset request error:", error); + } finally { + setIsSubmitting(false); + } + } + + return ( +
+ + + + + Forgot Password + + + Enter your email to receive a password reset link + + + + {success ? ( +
+
+
+
+ +
+
+

+ If an account with this email exists, a password reset link has been sent. +

+
+
+
+ +
+ ) : ( +
+ + ( + +
+ Email + +
+ + + +
+ )} + /> + + + {/* Error message */} + {errorState && ( + + + {errorState} + + + )} + + + )} +
+ +

+ {"Remember your password? "} + + Sign in + +

+
+
+
+
+ ); +} diff --git a/frontend/src/app/login/page.tsx b/frontend/src/app/login/page.tsx index bb7c1df..893eb28 100644 --- a/frontend/src/app/login/page.tsx +++ b/frontend/src/app/login/page.tsx @@ -118,23 +118,31 @@ export default function LoginPage() { )} /> - ( - - - - -
- Remember me -
-
- )} - /> +
+ ( + + + + +
+ Remember me +
+
+ )} + /> + + Forgot password? + +
diff --git a/frontend/src/app/reset-password/page.tsx b/frontend/src/app/reset-password/page.tsx new file mode 100644 index 0000000..c589ac1 --- /dev/null +++ b/frontend/src/app/reset-password/page.tsx @@ -0,0 +1,245 @@ +"use client"; + +import { motion } from "framer-motion"; +import Link from "next/link"; +import { zodResolver } from "@hookform/resolvers/zod"; +import { useForm } from "react-hook-form"; +import * as z from "zod"; +import { Button } from "@/components/ui/button"; +import { + Card, + CardContent, + CardDescription, + CardFooter, + CardHeader, + CardTitle, +} from "@/components/ui/card"; +import { + Form, + FormControl, + FormField, + FormItem, + FormLabel, + FormMessage, +} from "@/components/ui/form"; +import { Input } from "@/components/ui/input"; +import { useEffect, useState } from "react"; +import { apiCalls } from "@/utils/api"; +import { useRouter, useSearchParams } from "next/navigation"; + +const formSchema = z + .object({ + password: z.string().min(4, { + message: "Password must be at least 4 characters long.", + }), + confirm_password: z.string().min(4, { + message: "Password must be at least 4 characters long.", + }), + }) + .refine((data) => data.password === data.confirm_password, { + message: "Passwords do not match", + path: ["confirm_password"], + }); + +export default function ResetPasswordPage() { + const router = useRouter(); + const searchParams = useSearchParams(); + + const [token, setToken] = useState(null); + const [isSubmitting, setIsSubmitting] = useState(false); + const [success, setSuccess] = useState(false); + const [tokenError, setTokenError] = useState(false); + const [errorState, setErrorState] = useState(null); + + const form = useForm>({ + resolver: zodResolver(formSchema), + defaultValues: { + password: "", + confirm_password: "", + }, + }); + + useEffect(() => { + const tokenParam = searchParams?.get("token"); + if (!tokenParam) { + setTokenError(true); + } else { + setToken(tokenParam); + } + }, 
[searchParams]); + + async function onSubmit(values: z.infer) { + if (!token) { + setTokenError(true); + return; + } + + setIsSubmitting(true); + setErrorState(null); + + try { + const response = await apiCalls.resetPassword(token, values.password); + setSuccess(true); + + // Redirect to login page after 3 seconds + setTimeout(() => { + router.push("/login"); + }, 3000); + } catch (error) { + setErrorState("Failed to reset password. The link may have expired or is invalid."); + console.error("Password reset error:", error); + } finally { + setIsSubmitting(false); + } + } + + return ( +
+ + + + + Reset Password + + + Enter your new password + + + + {tokenError ? ( +
+
+
+
+ +
+
+

+ Missing or invalid reset token. Please request a new password reset link. +

+
+
+
+ +
+ ) : success ? ( +
+
+
+
+ +
+
+

+ Your password has been reset successfully. Redirecting to login... +

+
+
+
+
+ ) : ( +
+ + ( + +
+ New Password + +
+ + + +
+ )} + /> + ( + +
+ Confirm New Password + +
+ + + +
+ )} + /> + + + {/* Error message */} + {errorState && ( + + + {errorState} + + + )} + + + )} +
+ +

+ {"Remember your password? "} + + Sign in + +

+
+
+
+
+ ); +} diff --git a/frontend/src/app/verification-required/page.tsx b/frontend/src/app/verification-required/page.tsx new file mode 100644 index 0000000..b7b0955 --- /dev/null +++ b/frontend/src/app/verification-required/page.tsx @@ -0,0 +1,171 @@ +"use client"; + +import { motion } from "framer-motion"; +import Link from "next/link"; +import { Button } from "@/components/ui/button"; +import { + Card, + CardContent, + CardDescription, + CardFooter, + CardHeader, + CardTitle, +} from "@/components/ui/card"; +import { useState } from "react"; +import { apiCalls } from "@/utils/api"; +import { useAuth } from "@/utils/auth"; + +export default function VerificationRequiredPage() { + const { user, logout } = useAuth(); + const [isResending, setIsResending] = useState(false); + const [resendSuccess, setResendSuccess] = useState(false); + const [resendError, setResendError] = useState(null); + + const handleResendVerification = async () => { + if (!user) { + setResendError("User information not available. Please try logging in again."); + return; + } + + setIsResending(true); + setResendError(null); + + try { + await apiCalls.resendVerification(user); + setResendSuccess(true); + } catch (error) { + setResendError("Failed to resend verification email. Please try again later."); + console.error("Resend verification error:", error); + } finally { + setIsResending(false); + } + }; + + return ( +
+ + + + + Email Verification Required + + + Please verify your email address to continue + + + +
+
+
+ +
+
+

+ Your account requires email verification. We've sent a verification link to your email address. +

+
+
+
+ + {resendSuccess && ( +
+
+
+ +
+
+

+ Verification email has been resent. Please check your inbox. +

+
+
+
+ )} + + {resendError && ( +
+
+
+ +
+
+

+ {resendError} +

+
+
+
+ )} + +
+ + +
+
+ +

+ Already verified? Try{" "} + + logging in again + +

+
+
+
+
+ ); +} diff --git a/frontend/src/app/verify/page.tsx b/frontend/src/app/verify/page.tsx new file mode 100644 index 0000000..2273ba1 --- /dev/null +++ b/frontend/src/app/verify/page.tsx @@ -0,0 +1,159 @@ +"use client"; + +import { motion } from "framer-motion"; +import Link from "next/link"; +import { Button } from "@/components/ui/button"; +import { + Card, + CardContent, + CardDescription, + CardFooter, + CardHeader, + CardTitle, +} from "@/components/ui/card"; +import { useEffect, useState } from "react"; +import { apiCalls } from "@/utils/api"; +import { useRouter, useSearchParams } from "next/navigation"; + +export default function VerifyEmailPage() { + const router = useRouter(); + const searchParams = useSearchParams(); + + const [token, setToken] = useState(null); + const [status, setStatus] = useState<'loading' | 'success' | 'error'>('loading'); + const [errorMessage, setErrorMessage] = useState(''); + + useEffect(() => { + const tokenParam = searchParams?.get("token"); + if (!tokenParam) { + setStatus('error'); + setErrorMessage('Missing verification token.'); + return; + } + + setToken(tokenParam); + verifyEmail(tokenParam); + }, [searchParams]); + + const verifyEmail = async (token: string) => { + try { + const response = await apiCalls.verifyEmail(token); + setStatus('success'); + + // Redirect to login page after 3 seconds + setTimeout(() => { + router.push("/login"); + }, 3000); + } catch (error) { + setStatus('error'); + setErrorMessage('Failed to verify email. The link may have expired or is invalid.'); + console.error("Email verification error:", error); + } + }; + + const handleResendVerification = () => { + router.push("/resend-verification"); + }; + + return ( +
+ + + + + Email Verification + + + {status === 'loading' ? 'Verifying your email address...' : + status === 'success' ? 'Email verified successfully!' : + 'Email verification failed'} + + + + {status === 'loading' && ( +
+
+
+ )} + + {status === 'success' && ( +
+
+
+
+ +
+
+

+ Your email has been verified successfully! You can now log in. +

+
+
+
+

Redirecting to login page...

+
+ )} + + {status === 'error' && ( +
+
+
+
+ +
+
+

+ {errorMessage} +

+
+
+
+ +
+ )} +
+ +

+ {"Return to "} + + Sign in + +

+
+
+
+
+ ); +} diff --git a/frontend/src/components/ProtectedComponent.tsx b/frontend/src/components/ProtectedComponent.tsx index 9cc3cd9..dc7716e 100644 --- a/frontend/src/components/ProtectedComponent.tsx +++ b/frontend/src/components/ProtectedComponent.tsx @@ -5,27 +5,35 @@ import { usePathname, useRouter } from "next/navigation"; interface ProtectedComponentProps { children: React.ReactNode; + requireVerified?: boolean; } const ProtectedComponent: React.FC = ({ children, + requireVerified = true, }) => { const router = useRouter(); - const { token } = useAuth(); + const { token, isVerified } = useAuth(); const pathname = usePathname(); const [isClient, setIsClient] = useState(false); useEffect(() => { if (!token) { router.push("/login?sourcePage=" + encodeURIComponent(pathname)); + return; } - }, [token, pathname, router]); + + if (requireVerified && !isVerified) { + router.push("/verification-required"); + } + }, [token, isVerified, requireVerified, pathname, router]); // This is to prevent the page from starting to load the children before the token is checked useEffect(() => { setIsClient(true); }, []); - if (!token || !isClient) { + + if (!token || (requireVerified && !isVerified) || !isClient) { return null; } else { return <>{children}; diff --git a/frontend/src/utils/api.ts b/frontend/src/utils/api.ts index 86f8726..9a60457 100644 --- a/frontend/src/utils/api.ts +++ b/frontend/src/utils/api.ts @@ -91,10 +91,53 @@ const registerUser = async (username: string, password: string) => { } }; +const requestPasswordReset = async (username: string) => { + try { + const response = await api.post("/request-password-reset", { username }); + return response.data; + } catch (error) { + throw new Error("Error requesting password reset"); + } +}; + +const resetPassword = async (token: string, newPassword: string) => { + try { + const response = await api.post("/reset-password", { + token, + new_password: newPassword + }); + return response.data; + } catch (error) { + throw new 
Error("Error resetting password"); + } +}; + +const verifyEmail = async (token: string) => { + try { + const response = await api.post("/verify-email", { token }); + return response.data; + } catch (error) { + throw new Error("Error verifying email"); + } +}; + +const resendVerification = async (username: string) => { + try { + const response = await api.post("/resend-verification", { username }); + return response.data; + } catch (error) { + throw new Error("Error resending verification email"); + } +}; + export const apiCalls = { getUser, getLoginToken, getGoogleLoginToken, registerUser, + requestPasswordReset, + resetPassword, + verifyEmail, + resendVerification, }; export default api; diff --git a/frontend/src/utils/auth.tsx b/frontend/src/utils/auth.tsx index 015b122..805a3cb 100644 --- a/frontend/src/utils/auth.tsx +++ b/frontend/src/utils/auth.tsx @@ -1,11 +1,12 @@ "use client"; import { apiCalls } from "@/utils/api"; import { useRouter, useSearchParams } from "next/navigation"; -import { ReactNode, createContext, useContext, useState } from "react"; +import { ReactNode, createContext, useContext, useState, useEffect } from "react"; type AuthContextType = { token: string | null; user: string | null; + isVerified: boolean; login: (username: string, password: string) => void; logout: () => void; loginError: string | null; @@ -25,7 +26,6 @@ type AuthProviderProps = { }; const AuthProvider = ({ children }: AuthProviderProps) => { - const getInitialToken = () => { if (typeof window !== "undefined") { return localStorage.getItem("ee-token"); @@ -42,25 +42,64 @@ const AuthProvider = ({ children }: AuthProviderProps) => { const [user, setUser] = useState(getInitialUsername); const [token, setToken] = useState(getInitialToken); + const [isVerified, setIsVerified] = useState(false); const [loginError, setLoginError] = useState(null); const searchParams = useSearchParams(); const router = useRouter(); + // Check verification status on init if token exists + 
useEffect(() => { + const checkVerificationStatus = async () => { + const currentToken = getInitialToken(); + if (currentToken) { + try { + const userData = await apiCalls.getUser(currentToken); + setIsVerified(userData.is_verified); + } catch (error) { + console.error("Error fetching user verification status:", error); + } + } + }; + + checkVerificationStatus(); + }, []); + const login = async (username: string, password: string) => { const sourcePage = searchParams.has("sourcePage") ? decodeURIComponent(searchParams.get("sourcePage") as string) : "/"; try { - const { access_token } = await apiCalls.getLoginToken(username, password); + const response = await apiCalls.getLoginToken(username, password); + const { access_token } = response; + localStorage.setItem("ee-token", access_token); localStorage.setItem("ee-username", username); setUser(username); setToken(access_token); setLoginError(null); - router.push(sourcePage); + + // Check if verification status is in the response + if (response.is_verified !== undefined) { + setIsVerified(response.is_verified); + } else { + // If not in response, fetch user data to get verification status + try { + const userData = await apiCalls.getUser(access_token); + setIsVerified(userData.is_verified); + } catch (error) { + console.error("Error fetching user verification status:", error); + } + } + + // Redirect to verification page if not verified, otherwise to original destination + if (response.is_verified === false) { + router.push("/verification-required"); + } else { + router.push(sourcePage); + } } catch (error: unknown) { if ( error && @@ -86,27 +125,27 @@ const AuthProvider = ({ children }: AuthProviderProps) => { ? 
decodeURIComponent(searchParams.get("sourcePage") as string) : "/"; - apiCalls - .getGoogleLoginToken({ client_id: client_id, credential: credential }) - .then( - ({ - access_token, - username, - }: { - access_token: string; - username: string; - }) => { - localStorage.setItem("ee-token", access_token); - localStorage.setItem("ee-username", username); - setUser(username); - setToken(access_token); - router.push(sourcePage); - } - ) - .catch((error: Error) => { - setLoginError("Invalid Google credentials"); - console.error("Google login error:", error); + try { + const response = await apiCalls.getGoogleLoginToken({ + client_id: client_id, + credential: credential }); + + const { access_token, username } = response; + + localStorage.setItem("ee-token", access_token); + localStorage.setItem("ee-username", username); + + setUser(username); + setToken(access_token); + + setIsVerified(true); + + router.push(sourcePage); + } catch (error) { + setLoginError("Invalid Google credentials"); + console.error("Google login error:", error); + } }; const logout = () => { @@ -115,16 +154,18 @@ const AuthProvider = ({ children }: AuthProviderProps) => { localStorage.removeItem("ee-username"); setUser(null); setToken(null); + setIsVerified(false); router.push("/login"); }; const authValue: AuthContextType = { - token: token, - user: user, - login: login, - loginError: loginError, - loginGoogle: loginGoogle, - logout: logout, + token, + user, + isVerified, + login, + loginError, + loginGoogle, + logout, }; return ( @@ -141,3 +182,16 @@ export const useAuth = () => { } return context; }; + +export const useRequireVerified = () => { + const { token, isVerified } = useAuth(); + const router = useRouter(); + + useEffect(() => { + if (token && !isVerified) { + router.push("/verification-required"); + } + }, [token, isVerified, router]); + + return { token, isVerified }; +}; From 3e862414bedd455cac2c8fba644451ea930df25b Mon Sep 17 00:00:00 2001 From: Jay Prakash <0freerunning@gmail.com> 
Date: Thu, 3 Apr 2025 04:01:47 +0530 Subject: [PATCH 02/74] Added env variables for test --- .github/workflows/unit_tests.yaml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.github/workflows/unit_tests.yaml b/.github/workflows/unit_tests.yaml index 0c4ffd5..c9286b3 100644 --- a/.github/workflows/unit_tests.yaml +++ b/.github/workflows/unit_tests.yaml @@ -15,6 +15,12 @@ env: ADMIN_USERNAME: test@idinsight.org ADMIN_PASSWORD: test123 ADMIN_API_KEY: testkey123 + SES_REGION: ap-south-1 + SES_SENDER_EMAIL: no-reply@example.com + FRONTEND_URL: http://localhost:3000 + AWS_ACCESS_KEY_ID: aws_access_key + AWS_SECRET_ACCESS_KEY: aws_secret_key + ENV: testing jobs: container-job: runs-on: ubuntu-22.04 From b183861e5036c8bdb4195e30b5b43e06c280389a Mon Sep 17 00:00:00 2001 From: Jay Prakash <0freerunning@gmail.com> Date: Fri, 4 Apr 2025 04:14:05 +0530 Subject: [PATCH 03/74] Fix redirect issue --- .../src/components/ProtectedComponent.tsx | 29 ++++++++++--------- frontend/src/utils/auth.tsx | 21 +++++++++++--- 2 files changed, 32 insertions(+), 18 deletions(-) diff --git a/frontend/src/components/ProtectedComponent.tsx b/frontend/src/components/ProtectedComponent.tsx index dc7716e..f15a761 100644 --- a/frontend/src/components/ProtectedComponent.tsx +++ b/frontend/src/components/ProtectedComponent.tsx @@ -13,27 +13,28 @@ const ProtectedComponent: React.FC = ({ requireVerified = true, }) => { const router = useRouter(); - const { token, isVerified } = useAuth(); + const { token, isVerified, isLoading } = useAuth(); const pathname = usePathname(); const [isClient, setIsClient] = useState(false); - useEffect(() => { - if (!token) { - router.push("/login?sourcePage=" + encodeURIComponent(pathname)); - return; - } - - if (requireVerified && !isVerified) { - router.push("/verification-required"); - } - }, [token, isVerified, requireVerified, pathname, router]); - - // This is to prevent the page from starting to load the children before the token is checked useEffect(() => { 
setIsClient(true); }, []); - if (!token || (requireVerified && !isVerified) || !isClient) { + useEffect(() => { + if (isClient && !isLoading) { + if (!token) { + router.push("/login?sourcePage=" + encodeURIComponent(pathname)); + return; + } + + if (requireVerified && !isVerified) { + router.push("/verification-required"); + } + } + }, [token, isVerified, isLoading, requireVerified, pathname, router, isClient]); + + if (!isClient || isLoading || !token || (requireVerified && !isVerified)) { return null; } else { return <>{children}; diff --git a/frontend/src/utils/auth.tsx b/frontend/src/utils/auth.tsx index 805a3cb..ee07bad 100644 --- a/frontend/src/utils/auth.tsx +++ b/frontend/src/utils/auth.tsx @@ -7,6 +7,7 @@ type AuthContextType = { token: string | null; user: string | null; isVerified: boolean; + isLoading: boolean; login: (username: string, password: string) => void; logout: () => void; loginError: string | null; @@ -43,6 +44,7 @@ const AuthProvider = ({ children }: AuthProviderProps) => { const [user, setUser] = useState(getInitialUsername); const [token, setToken] = useState(getInitialToken); const [isVerified, setIsVerified] = useState(false); + const [isLoading, setIsLoading] = useState(!!getInitialToken()); const [loginError, setLoginError] = useState(null); const searchParams = useSearchParams(); @@ -53,11 +55,15 @@ const AuthProvider = ({ children }: AuthProviderProps) => { const checkVerificationStatus = async () => { const currentToken = getInitialToken(); if (currentToken) { + setIsLoading(true); try { const userData = await apiCalls.getUser(currentToken); setIsVerified(userData.is_verified); } catch (error) { console.error("Error fetching user verification status:", error); + logout(); + } finally { + setIsLoading(false); } } }; @@ -71,6 +77,7 @@ const AuthProvider = ({ children }: AuthProviderProps) => { : "/"; try { + setIsLoading(true); const response = await apiCalls.getLoginToken(username, password); const { access_token } = response; @@ 
-111,6 +118,8 @@ const AuthProvider = ({ children }: AuthProviderProps) => { } else { setLoginError("An unexpected error occurred. Please try again later."); } + } finally { + setIsLoading(false); } }; @@ -126,6 +135,7 @@ const AuthProvider = ({ children }: AuthProviderProps) => { : "/"; try { + setIsLoading(true); const response = await apiCalls.getGoogleLoginToken({ client_id: client_id, credential: credential @@ -145,6 +155,8 @@ const AuthProvider = ({ children }: AuthProviderProps) => { } catch (error) { setLoginError("Invalid Google credentials"); console.error("Google login error:", error); + } finally { + setIsLoading(false); } }; @@ -162,6 +174,7 @@ const AuthProvider = ({ children }: AuthProviderProps) => { token, user, isVerified, + isLoading, login, loginError, loginGoogle, @@ -184,14 +197,14 @@ export const useAuth = () => { }; export const useRequireVerified = () => { - const { token, isVerified } = useAuth(); + const { token, isVerified, isLoading } = useAuth(); const router = useRouter(); useEffect(() => { - if (token && !isVerified) { + if (!isLoading && token && !isVerified) { router.push("/verification-required"); } - }, [token, isVerified, router]); + }, [token, isVerified, isLoading, router]); - return { token, isVerified }; + return { token, isVerified, isLoading }; }; From c79830bd71950cef942f76874304d8d9bfabf31c Mon Sep 17 00:00:00 2001 From: Jay Prakash <0freerunning@gmail.com> Date: Tue, 22 Apr 2025 03:04:43 +0530 Subject: [PATCH 04/74] Adding the workspace feature --- backend/app/__init__.py | 2 + backend/app/auth/dependencies.py | 7 +- backend/app/auth/routers.py | 47 +- backend/app/auth/utils.py | 8 - backend/app/contextual_mab/models.py | 20 +- backend/app/contextual_mab/routers.py | 82 ++- backend/app/contextual_mab/schemas.py | 1 + backend/app/email.py | 59 +++ backend/app/mab/models.py | 19 +- backend/app/mab/routers.py | 98 +++- backend/app/mab/schemas.py | 1 + backend/app/models.py | 7 +- backend/app/users/models.py | 3 +- 
backend/app/users/routers.py | 14 +- backend/app/users/schemas.py | 3 + backend/app/workspaces/__init__.py | 0 backend/app/workspaces/models.py | 303 +++++++++++ backend/app/workspaces/routers.py | 479 ++++++++++++++++++ backend/app/workspaces/schemas.py | 111 ++++ backend/app/workspaces/utils.py | 141 ++++++ .../949c9fc0461d_workspace_relationship.py | 32 ++ .../versions/977e7e73ce06_workspace_model.py | 56 ++ .../app/(protected)/workspace/create/page.tsx | 131 +++++ .../app/(protected)/workspace/invite/page.tsx | 193 +++++++ .../src/app/(protected)/workspace/page.tsx | 116 +++++ .../src/app/(protected)/workspace/types.ts | 51 ++ frontend/src/components/WorkspaceSelector.tsx | 142 ++++++ frontend/src/components/sidebar.tsx | 23 +- frontend/src/utils/api.ts | 90 +++- frontend/src/utils/auth.tsx | 82 ++- 30 files changed, 2256 insertions(+), 65 deletions(-) create mode 100644 backend/app/workspaces/__init__.py create mode 100644 backend/app/workspaces/models.py create mode 100644 backend/app/workspaces/routers.py create mode 100644 backend/app/workspaces/schemas.py create mode 100644 backend/app/workspaces/utils.py create mode 100644 backend/migrations/versions/949c9fc0461d_workspace_relationship.py create mode 100644 backend/migrations/versions/977e7e73ce06_workspace_model.py create mode 100644 frontend/src/app/(protected)/workspace/create/page.tsx create mode 100644 frontend/src/app/(protected)/workspace/invite/page.tsx create mode 100644 frontend/src/app/(protected)/workspace/page.tsx create mode 100644 frontend/src/app/(protected)/workspace/types.ts create mode 100644 frontend/src/components/WorkspaceSelector.tsx diff --git a/backend/app/__init__.py b/backend/app/__init__.py index 845b9c2..2653972 100644 --- a/backend/app/__init__.py +++ b/backend/app/__init__.py @@ -10,6 +10,7 @@ from .users.routers import ( router as users_router, ) # to avoid circular imports +from .workspaces.routers import router as workspaces_router from .utils import setup_logger logger = 
setup_logger() @@ -40,6 +41,7 @@ def create_app() -> FastAPI: app.include_router(auth.router) app.include_router(users_router) app.include_router(messages.router) + app.include_router(workspaces_router) origins = [ "http://localhost", diff --git a/backend/app/auth/dependencies.py b/backend/app/auth/dependencies.py index 534c87f..d34a593 100644 --- a/backend/app/auth/dependencies.py +++ b/backend/app/auth/dependencies.py @@ -12,7 +12,7 @@ from jwt.exceptions import InvalidTokenError from sqlalchemy.ext.asyncio import AsyncSession -from ..config import CHECK_API_LIMIT, DEFAULT_API_QUOTA, DEFAULT_EXPERIMENTS_QUOTA, ENV +from ..config import CHECK_API_LIMIT, DEFAULT_API_QUOTA, DEFAULT_EXPERIMENTS_QUOTA from ..database import get_async_session from ..users.models import ( UserDB, @@ -185,7 +185,7 @@ async def get_verified_user( return user_db -def create_access_token(username: str) -> str: +def create_access_token(username: str, workspace_name: str = None) -> str: """ Create an access token for the user """ @@ -198,6 +198,9 @@ def create_access_token(username: str) -> str: payload["iat"] = datetime.now(timezone.utc) payload["sub"] = username payload["type"] = "access_token" + + if workspace_name: + payload["workspace_name"] = workspace_name return jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM) diff --git a/backend/app/auth/routers.py b/backend/app/auth/routers.py index eec0c4a..c989ae8 100644 --- a/backend/app/auth/routers.py +++ b/backend/app/auth/routers.py @@ -6,6 +6,8 @@ from redis.asyncio import Redis from sqlalchemy.ext.asyncio import AsyncSession +from ..config import DEFAULT_API_QUOTA, DEFAULT_EXPERIMENTS_QUOTA + from ..database import get_async_session, get_redis from ..email import EmailService from ..users.models import ( @@ -19,6 +21,7 @@ MessageResponse, PasswordResetConfirm, PasswordResetRequest, + UserCreate, ) from ..utils import setup_logger from .config import NEXT_PUBLIC_GOOGLE_LOGIN_CLIENT_ID @@ -96,8 +99,17 @@ async def login_google( 
except ValueError as e: raise HTTPException(status_code=401, detail="Invalid token") from e + # Import here to avoid circular imports + from ..workspaces.models import ( + create_user_workspace_role, + get_user_default_workspace, + UserRoles + ) + from ..workspaces.utils import create_workspace + + user_email = idinfo["email"] user = await authenticate_or_create_google_user( - request=request, google_email=idinfo["email"], asession=asession + request=request, google_email=user_email, asession=asession ) if not user: raise HTTPException( @@ -105,8 +117,39 @@ async def login_google( detail="Unable to create new user", ) + user_db = await get_user_by_username(username=user_email, asession=asession) + + # Create default workspace if user is new (has no workspaces) + try: + default_workspace = await get_user_default_workspace(asession=asession, user_db=user_db) + default_workspace_name = default_workspace.workspace_name + except Exception: + # User doesn't have a default workspace, create one + default_workspace_name = f"{user_email}'s Workspace" + + # Create default workspace + workspace_db, _ = await create_workspace( + api_daily_quota=DEFAULT_API_QUOTA, + asession=asession, + content_quota=DEFAULT_EXPERIMENTS_QUOTA, + user=UserCreate( + role=UserRoles.ADMIN, + username=user_email, + workspace_name=default_workspace_name, + ), + is_default=True + ) + + await create_user_workspace_role( + asession=asession, + is_default_workspace=True, + user_db=user_db, + user_role=UserRoles.ADMIN, + workspace_db=workspace_db, + ) + return AuthenticationDetails( - access_token=create_access_token(user.username), + access_token=create_access_token(user.username, default_workspace_name), api_key_first_characters=user.api_key_first_characters, token_type="bearer", access_level=user.access_level, diff --git a/backend/app/auth/utils.py b/backend/app/auth/utils.py index cf8c15a..05b8f69 100644 --- a/backend/app/auth/utils.py +++ b/backend/app/auth/utils.py @@ -6,20 +6,12 @@ from 
redis.asyncio import Redis from ..utils import setup_logger -<<<<<<< HEAD -from .config import JWT_ALGORITHM, JWT_SECRET - -# Token settings -VERIFICATION_TOKEN_EXPIRE_MINUTES = 1440 # 24 hours -PASSWORD_RESET_TOKEN_EXPIRE_MINUTES = 30 # 30 minutes -======= from .config import ( JWT_ALGORITHM, JWT_SECRET, PASSWORD_RESET_TOKEN_EXPIRE_MINUTES, VERIFICATION_TOKEN_EXPIRE_MINUTES, ) ->>>>>>> origin logger = setup_logger() diff --git a/backend/app/contextual_mab/models.py b/backend/app/contextual_mab/models.py index ca27a8b..da4d699 100644 --- a/backend/app/contextual_mab/models.py +++ b/backend/app/contextual_mab/models.py @@ -59,6 +59,7 @@ def to_dict(self) -> dict: return { "experiment_id": self.experiment_id, "user_id": self.user_id, + "workspace_id": self.workspace_id, "name": self.name, "description": self.description, "created_datetime_utc": self.created_datetime_utc, @@ -191,6 +192,7 @@ def to_dict(self) -> dict: async def save_contextual_mab_to_db( experiment: ContextualBandit, user_id: int, + workspace_id: int, asession: AsyncSession, ) -> ContextualBanditDB: """ @@ -225,6 +227,7 @@ async def save_contextual_mab_to_db( name=experiment.name, description=experiment.description, user_id=user_id, + workspace_id=workspace_id, is_active=experiment.is_active, created_datetime_utc=datetime.now(timezone.utc), n_trials=0, @@ -243,15 +246,17 @@ async def save_contextual_mab_to_db( async def get_all_contextual_mabs( user_id: int, + workspace_id: int, asession: AsyncSession, ) -> Sequence[ContextualBanditDB]: """ - Get all the contextual experiments from the database. + Get all the contextual experiments from the database for a specific workspace. 
""" statement = ( select(ContextualBanditDB) .where( ContextualBanditDB.user_id == user_id, + ContextualBanditDB.workspace_id == workspace_id, ) .order_by(ContextualBanditDB.experiment_id) ) @@ -260,7 +265,10 @@ async def get_all_contextual_mabs( async def get_contextual_mab_by_id( - experiment_id: int, user_id: int, asession: AsyncSession + experiment_id: int, + user_id: int, + workspace_id: int, + asession: AsyncSession ) -> ContextualBanditDB | None: """ Get the contextual experiment by id. @@ -268,6 +276,7 @@ async def get_contextual_mab_by_id( result = await asession.execute( select(ContextualBanditDB) .where(ContextualBanditDB.user_id == user_id) + .where(ContextualBanditDB.workspace_id == workspace_id) .where(ContextualBanditDB.experiment_id == experiment_id) ) @@ -275,7 +284,10 @@ async def get_contextual_mab_by_id( async def delete_contextual_mab_by_id( - experiment_id: int, user_id: int, asession: AsyncSession + experiment_id: int, + user_id: int, + workspace_id: int, + asession: AsyncSession ) -> None: """ Delete the contextual experiment by id. 
@@ -289,7 +301,6 @@ async def delete_contextual_mab_by_id( await asession.execute( delete(ContextualObservationDB).where( and_( - ContextualObservationDB.user_id == ObservationsBaseDB.user_id, ContextualObservationDB.user_id == user_id, ContextualObservationDB.experiment_id == experiment_id, ) @@ -317,6 +328,7 @@ async def delete_contextual_mab_by_id( and_( ContextualBanditDB.experiment_id == ExperimentBaseDB.experiment_id, ContextualBanditDB.user_id == user_id, + ContextualBanditDB.workspace_id == workspace_id, ContextualBanditDB.experiment_id == experiment_id, ) ) diff --git a/backend/app/contextual_mab/routers.py b/backend/app/contextual_mab/routers.py index 95a6524..c56aa32 100644 --- a/backend/app/contextual_mab/routers.py +++ b/backend/app/contextual_mab/routers.py @@ -4,6 +4,9 @@ from fastapi.exceptions import HTTPException from sqlalchemy.ext.asyncio import AsyncSession +from ..workspaces.models import get_user_default_workspace, get_user_role_in_workspace +from ..workspaces.schemas import UserRoles + from ..auth.dependencies import authenticate_key, get_verified_user from ..database import get_async_session from ..models import get_notifications_from_db, save_notifications_to_db @@ -41,7 +44,24 @@ async def create_contextual_mabs( """ Create a new contextual experiment with different priors for each context. 
""" - cmab = await save_contextual_mab_to_db(experiment, user_db.user_id, asession) + workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) + + user_role = await get_user_role_in_workspace( + asession=asession, user_db=user_db, workspace_db=workspace_db + ) + + if user_role != UserRoles.ADMIN: + raise HTTPException( + status_code=403, + detail="Only workspace administrators can create experiments.", + ) + + cmab = await save_contextual_mab_to_db( + experiment, + user_db.user_id, + workspace_db.workspace_id, + asession + ) notifications = await save_notifications_to_db( experiment_id=cmab.experiment_id, user_id=user_db.user_id, @@ -61,7 +81,13 @@ async def get_contextual_mabs( """ Get details of all experiments. """ - experiments = await get_all_contextual_mabs(user_db.user_id, asession) + workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) + + experiments = await get_all_contextual_mabs( + user_db.user_id, + workspace_db.workspace_id, + asession + ) all_experiments = [] for exp in experiments: exp_dict = exp.to_dict() @@ -95,8 +121,13 @@ async def get_contextual_mab( """ Get details of experiment with the provided `experiment_id`. """ + workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) + experiment = await get_contextual_mab_by_id( - experiment_id, user_db.user_id, asession + experiment_id, + user_db.user_id, + workspace_db.workspace_id, + asession ) if experiment is None: raise HTTPException( @@ -123,14 +154,34 @@ async def delete_contextual_mab( Delete the experiment with the provided `experiment_id`. 
""" try: + workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) + + user_role = await get_user_role_in_workspace( + asession=asession, user_db=user_db, workspace_db=workspace_db + ) + + if user_role != UserRoles.ADMIN: + raise HTTPException( + status_code=403, + detail="Only workspace administrators can delete experiments.", + ) + experiment = await get_contextual_mab_by_id( - experiment_id, user_db.user_id, asession + experiment_id, + user_db.user_id, + workspace_db.workspace_id, + asession ) if experiment is None: raise HTTPException( status_code=404, detail=f"Experiment with id {experiment_id} not found" ) - await delete_contextual_mab_by_id(experiment_id, user_db.user_id, asession) + await delete_contextual_mab_by_id( + experiment_id, + user_db.user_id, + workspace_db.workspace_id, + asession + ) return {"detail": f"Experiment {experiment_id} deleted successfully."} except Exception as e: raise HTTPException(status_code=500, detail=f"Error: {e}") from e @@ -146,8 +197,13 @@ async def draw_arm( """ Get which arm to pull next for provided experiment. """ + workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) + experiment = await get_contextual_mab_by_id( - experiment_id, user_db.user_id, asession + experiment_id, + user_db.user_id, + workspace_db.workspace_id, + asession ) if experiment is None: @@ -190,9 +246,14 @@ async def update_arm( Update the arm with the provided `arm_id` for the given `experiment_id` based on the `outcome`. """ + workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) + # Get the experiment and do checks experiment = await get_contextual_mab_by_id( - experiment_id, user_db.user_id, asession + experiment_id, + user_db.user_id, + workspace_db.workspace_id, + asession ) if experiment is None: raise HTTPException( @@ -284,8 +345,13 @@ async def get_outcomes( """ Get the outcomes for the experiment. 
""" + workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) + experiment = await get_contextual_mab_by_id( - experiment_id, user_db.user_id, asession + experiment_id, + user_db.user_id, + workspace_db.workspace_id, + asession ) if not experiment: raise HTTPException( diff --git a/backend/app/contextual_mab/schemas.py b/backend/app/contextual_mab/schemas.py index 79f1ee6..e09a8ba 100644 --- a/backend/app/contextual_mab/schemas.py +++ b/backend/app/contextual_mab/schemas.py @@ -171,6 +171,7 @@ class ContextualBanditResponse(ContextualBanditBase): """ experiment_id: int + workspace_id: int arms: list[ContextualArmResponse] contexts: list[ContextResponse] notifications: list[NotificationsResponse] diff --git a/backend/app/email.py b/backend/app/email.py index 5776735..9770090 100644 --- a/backend/app/email.py +++ b/backend/app/email.py @@ -124,6 +124,65 @@ async def send_password_reset_email( return await self._send_email(email, subject, html_body, text_body) + async def send_workspace_invitation_email( + self, email: str, username: str, inviter_email: str, workspace_name: str, user_exists: bool + ) -> Dict[str, Any]: + """ + Send workspace invitation email + """ + if user_exists: + subject = f"You've been invited to join a workspace: {workspace_name}" + html_body = f""" + + + +

Workspace Invitation

+

Hello {username},

+

You have been invited by {inviter_email} to join the workspace "{workspace_name}".

+

You have been added to this workspace. Log in to access it.

+

Login to Your Account

+ + + """ + text_body = f""" + Workspace Invitation + + Hello {username}, + + You have been invited by {inviter_email} to join the workspace "{workspace_name}". + + You have been added to this workspace. Log in to access it. + + {FRONTEND_URL}/login + """ + else: + subject = f"Invitation to Create an Account and Join a Workspace: {workspace_name}" + html_body = f""" + + + +

Workspace Invitation

+

Hello,

+

You have been invited by {inviter_email} to join the workspace "{workspace_name}".

+

You need to create an account to join this workspace.

+

Create Your Account

+ + + """ + text_body = f""" + Workspace Invitation + + Hello, + + You have been invited by {inviter_email} to join the workspace "{workspace_name}". + + You need to create an account to join this workspace. + + {FRONTEND_URL}/signup + """ + + return await self._send_email(email, subject, html_body, text_body) + async def _send_email( self, recipient: str, subject: str, html_body: str, text_body: str ) -> Dict[str, Any]: diff --git a/backend/app/mab/models.py b/backend/app/mab/models.py index 799e293..377498d 100644 --- a/backend/app/mab/models.py +++ b/backend/app/mab/models.py @@ -50,6 +50,7 @@ def to_dict(self) -> dict: return { "experiment_id": self.experiment_id, "user_id": self.user_id, + "workspace_id": self.workspace_id, "name": self.name, "description": self.description, "created_datetime_utc": self.created_datetime_utc, @@ -143,6 +144,7 @@ def to_dict(self) -> dict: async def save_mab_to_db( experiment: MultiArmedBandit, user_id: int, + workspace_id: int, asession: AsyncSession, ) -> MultiArmedBanditDB: """ @@ -159,6 +161,7 @@ async def save_mab_to_db( name=experiment.name, description=experiment.description, user_id=user_id, + workspace_id=workspace_id, is_active=experiment.is_active, created_datetime_utc=datetime.now(timezone.utc), n_trials=0, @@ -176,15 +179,17 @@ async def save_mab_to_db( async def get_all_mabs( user_id: int, + workspace_id: int, asession: AsyncSession, ) -> Sequence[MultiArmedBanditDB]: """ - Get all the experiments from the database. + Get all the experiments from the database for a specific workspace. 
""" statement = ( select(MultiArmedBanditDB) .where( MultiArmedBanditDB.user_id == user_id, + MultiArmedBanditDB.workspace_id == workspace_id, ) .order_by(MultiArmedBanditDB.experiment_id) ) @@ -193,7 +198,10 @@ async def get_all_mabs( async def get_mab_by_id( - experiment_id: int, user_id: int, asession: AsyncSession + experiment_id: int, + user_id: int, + workspace_id: int, + asession: AsyncSession ) -> MultiArmedBanditDB | None: """ Get the experiment by id. @@ -201,6 +209,7 @@ async def get_mab_by_id( result = await asession.execute( select(MultiArmedBanditDB) .where(MultiArmedBanditDB.user_id == user_id) + .where(MultiArmedBanditDB.workspace_id == workspace_id) .where(MultiArmedBanditDB.experiment_id == experiment_id) ) @@ -208,7 +217,10 @@ async def get_mab_by_id( async def delete_mab_by_id( - experiment_id: int, user_id: int, asession: AsyncSession + experiment_id: int, + user_id: int, + workspace_id: int, + asession: AsyncSession ) -> None: """ Delete the experiment by id. @@ -243,6 +255,7 @@ async def delete_mab_by_id( MultiArmedBanditDB.experiment_id == experiment_id, MultiArmedBanditDB.experiment_id == ExperimentBaseDB.experiment_id, MultiArmedBanditDB.user_id == user_id, + MultiArmedBanditDB.workspace_id == workspace_id, ) ) ) diff --git a/backend/app/mab/routers.py b/backend/app/mab/routers.py index 7fd838f..fd5af38 100644 --- a/backend/app/mab/routers.py +++ b/backend/app/mab/routers.py @@ -5,6 +5,9 @@ from fastapi.exceptions import HTTPException from sqlalchemy.ext.asyncio import AsyncSession +from ..workspaces.models import get_user_default_workspace, get_user_role_in_workspace +from ..workspaces.schemas import UserRoles + from ..auth.dependencies import authenticate_key, get_verified_user from ..database import get_async_session from ..models import get_notifications_from_db, save_notifications_to_db @@ -38,9 +41,27 @@ async def create_mab( asession: AsyncSession = Depends(get_async_session), ) -> MultiArmedBanditResponse: """ - Create a new 
experiment. + Create a new experiment in the user's current workspace. """ - mab = await save_mab_to_db(experiment, user_db.user_id, asession) + workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) + + user_role = await get_user_role_in_workspace( + asession=asession, user_db=user_db, workspace_db=workspace_db + ) + + if user_role != UserRoles.ADMIN: + raise HTTPException( + status_code=403, + detail="Only workspace administrators can create experiments.", + ) + + mab = await save_mab_to_db( + experiment, + user_db.user_id, + workspace_db.workspace_id, + asession + ) + notifications = await save_notifications_to_db( experiment_id=mab.experiment_id, user_id=user_db.user_id, @@ -60,9 +81,15 @@ async def get_mabs( asession: AsyncSession = Depends(get_async_session), ) -> list[MultiArmedBanditResponse]: """ - Get details of all experiments. + Get details of all experiments in the user's current workspace. """ - experiments = await get_all_mabs(user_db.user_id, asession) + workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) + + experiments = await get_all_mabs( + user_db.user_id, + workspace_db.workspace_id, + asession + ) all_experiments = [] for exp in experiments: @@ -95,7 +122,14 @@ async def get_mab( """ Get details of experiment with the provided `experiment_id`. """ - experiment = await get_mab_by_id(experiment_id, user_db.user_id, asession) + workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) + + experiment = await get_mab_by_id( + experiment_id, + user_db.user_id, + workspace_db.workspace_id, + asession + ) if experiment is None: raise HTTPException( @@ -123,12 +157,34 @@ async def delete_mab( Delete the experiment with the provided `experiment_id`. 
""" try: - experiment = await get_mab_by_id(experiment_id, user_db.user_id, asession) + workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) + + user_role = await get_user_role_in_workspace( + asession=asession, user_db=user_db, workspace_db=workspace_db + ) + + if user_role != UserRoles.ADMIN: + raise HTTPException( + status_code=403, + detail="Only workspace administrators can delete experiments.", + ) + + experiment = await get_mab_by_id( + experiment_id, + user_db.user_id, + workspace_db.workspace_id, + asession + ) if experiment is None: raise HTTPException( status_code=404, detail=f"Experiment with id {experiment_id} not found" ) - await delete_mab_by_id(experiment_id, user_db.user_id, asession) + await delete_mab_by_id( + experiment_id, + user_db.user_id, + workspace_db.workspace_id, + asession + ) return {"message": f"Experiment with id {experiment_id} deleted successfully."} except Exception as e: raise HTTPException(status_code=500, detail=f"Error: {e}") from e @@ -143,7 +199,14 @@ async def draw_arm( """ Get which arm to pull next for provided experiment. """ - experiment = await get_mab_by_id(experiment_id, user_db.user_id, asession) + workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) + + experiment = await get_mab_by_id( + experiment_id, + user_db.user_id, + workspace_db.workspace_id, + asession + ) if experiment is None: raise HTTPException( status_code=404, detail=f"Experiment with id {experiment_id} not found" @@ -165,8 +228,14 @@ async def update_arm( Update the arm with the provided `arm_id` for the given `experiment_id` based on the `outcome`. 
""" - # Get and validate experiment - experiment = await get_mab_by_id(experiment_id, user_db.user_id, asession) + workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) + + experiment = await get_mab_by_id( + experiment_id, + user_db.user_id, + workspace_db.workspace_id, + asession + ) if experiment is None: raise HTTPException( status_code=404, detail=f"Experiment with id {experiment_id} not found" @@ -233,7 +302,14 @@ async def get_outcomes( """ Get the outcomes for the experiment. """ - experiment = await get_mab_by_id(experiment_id, user_db.user_id, asession) + workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) + + experiment = await get_mab_by_id( + experiment_id, + user_db.user_id, + workspace_db.workspace_id, + asession + ) if not experiment: raise HTTPException( status_code=404, detail=f"Experiment with id {experiment_id} not found" diff --git a/backend/app/mab/schemas.py b/backend/app/mab/schemas.py index 18cd4c7..6270319 100644 --- a/backend/app/mab/schemas.py +++ b/backend/app/mab/schemas.py @@ -172,6 +172,7 @@ class MultiArmedBanditResponse(MultiArmedBanditBase): """ experiment_id: int + workspace_id: int arms: list[ArmResponse] notifications: list[NotificationsResponse] created_datetime_utc: datetime diff --git a/backend/app/models.py b/backend/app/models.py index caa1ac0..5314ebb 100644 --- a/backend/app/models.py +++ b/backend/app/models.py @@ -3,7 +3,7 @@ from sqlalchemy import Boolean, DateTime, Enum, ForeignKey, Integer, String, select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship from .schemas import EventType, Notifications @@ -29,6 +29,9 @@ class ExperimentBaseDB(Base): user_id: Mapped[int] = mapped_column( Integer, ForeignKey("users.user_id"), nullable=False ) + workspace_id: Mapped[int] = mapped_column( + Integer, 
ForeignKey("workspace.workspace_id"), nullable=False + ) is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) exp_type: Mapped[str] = mapped_column(String(length=50), nullable=False) prior_type: Mapped[str] = mapped_column(String(length=50), nullable=False) @@ -41,6 +44,8 @@ class ExperimentBaseDB(Base): last_trial_datetime_utc: Mapped[datetime] = mapped_column( DateTime(timezone=True), nullable=True ) + workspace: Mapped["WorkspaceDB"] = relationship("WorkspaceDB", back_populates="experiments") + __mapper_args__ = { "polymorphic_identity": "experiment", "polymorphic_on": "exp_type", diff --git a/backend/app/users/models.py b/backend/app/users/models.py index 4de00f8..6f8bfc2 100644 --- a/backend/app/users/models.py +++ b/backend/app/users/models.py @@ -56,8 +56,7 @@ class UserDB(Base): ) is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) is_verified: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) - - # Relationships for workspaces + user_workspaces: Mapped[list["UserWorkspaceDB"]] = relationship( "UserWorkspaceDB", back_populates="user", diff --git a/backend/app/users/routers.py b/backend/app/users/routers.py index 5980200..90844bb 100644 --- a/backend/app/users/routers.py +++ b/backend/app/users/routers.py @@ -6,6 +6,8 @@ from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession +from ..config import DEFAULT_API_QUOTA, DEFAULT_EXPERIMENTS_QUOTA + from ..auth.dependencies import get_current_user, get_verified_user from ..auth.utils import generate_verification_token from ..database import get_async_session, get_redis @@ -44,9 +46,9 @@ async def create_user( """ try: # Import workspace functionality to avoid circular imports - from ..workspaces.models import create_user_workspace_role, UserRoles + from ..workspaces.models import UserRoles, create_user_workspace_role from ..workspaces.utils import create_workspace - + # Create the user new_api_key = 
generate_key() user_new = await save_user_to_db( @@ -60,20 +62,20 @@ async def create_user( # Create default workspace for the user default_workspace_name = f"{user_new.username}'s Workspace" workspace_api_key = generate_key() - + workspace_db, _ = await create_workspace( api_daily_quota=DEFAULT_API_QUOTA, asession=asession, - content_quota=DEFAULT_CONTENT_QUOTA, + content_quota=DEFAULT_EXPERIMENTS_QUOTA, user=UserCreate( role=UserRoles.ADMIN, username=user_new.username, workspace_name=default_workspace_name, ), is_default=True, - api_key=workspace_api_key + api_key=workspace_api_key, ) - + # Add user to workspace as admin await create_user_workspace_role( asession=asession, diff --git a/backend/app/users/schemas.py b/backend/app/users/schemas.py index 5bb3294..c2fb6b5 100644 --- a/backend/app/users/schemas.py +++ b/backend/app/users/schemas.py @@ -12,6 +12,9 @@ class UserCreate(BaseModel): username: str experiments_quota: Optional[int] = None api_daily_quota: Optional[int] = None + workspace_name: Optional[str] = None + role: Optional[str] = None + is_default_workspace: Optional[bool] = False model_config = ConfigDict(from_attributes=True) diff --git a/backend/app/workspaces/__init__.py b/backend/app/workspaces/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/workspaces/models.py b/backend/app/workspaces/models.py new file mode 100644 index 0000000..f7b37a2 --- /dev/null +++ b/backend/app/workspaces/models.py @@ -0,0 +1,303 @@ +from datetime import datetime, timezone +from typing import Sequence, TYPE_CHECKING + +import sqlalchemy.sql.functions as func +from sqlalchemy import ( + Boolean, + DateTime, + Enum, + ForeignKey, + Integer, + Row, + String, + case, + exists, + select, + text, + update, +) +from sqlalchemy.exc import NoResultFound +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from ..users.schemas import UserCreate + +from ..models import Base, 
ExperimentBaseDB +from ..utils import get_key_hash +from .schemas import UserRoles, UserCreateWithCode + +# Use TYPE_CHECKING for type hints without runtime imports +if TYPE_CHECKING: + from ..users.models import UserDB + + +class UserWorkspaceRoleAlreadyExistsError(Exception): + """Exception raised when a user workspace role already exists in the database.""" + + +class UserNotFoundInWorkspaceError(Exception): + """Exception raised when a user is not found in a workspace in the database.""" + + +class WorkspaceDB(Base): + """ORM for managing workspaces. + + A workspace is an isolated virtual environment that contains contents that can be + accessed and modified by users assigned to that workspace. Workspaces must be + unique but can contain duplicated content. Users can be assigned to one more + workspaces, with different roles. In other words, there is a MANY-to-MANY + relationship between users and workspaces. + """ + + __tablename__ = "workspace" + + api_daily_quota: Mapped[int | None] = mapped_column(Integer, nullable=True) + api_key_first_characters: Mapped[str] = mapped_column(String(5), nullable=True) + api_key_updated_datetime_utc: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=True + ) + content_quota: Mapped[int | None] = mapped_column(Integer, nullable=True) + created_datetime_utc: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False + ) + hashed_api_key: Mapped[str] = mapped_column(String(96), nullable=True, unique=True) + updated_datetime_utc: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False + ) + user_workspaces: Mapped[list["UserWorkspaceDB"]] = relationship( + "UserWorkspaceDB", + back_populates="workspace", + cascade="all, delete-orphan", + passive_deletes=True, + ) + users: Mapped[list["UserDB"]] = relationship( + "UserDB", back_populates="workspaces", secondary="user_workspace", viewonly=True + ) + workspace_id: Mapped[int] = mapped_column(Integer, primary_key=True, 
nullable=False) + workspace_name: Mapped[str] = mapped_column(String, nullable=False, unique=True) + is_default: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + experiments: Mapped[list["ExperimentBaseDB"]] = relationship( + "ExperimentBaseDB", back_populates="workspace", cascade="all, delete-orphan" + ) + + def __repr__(self) -> str: + """Define the string representation for the `WorkspaceDB` class.""" + return f"" + + +class UserWorkspaceDB(Base): + """ORM for managing user in workspaces.""" + + __tablename__ = "user_workspace" + + created_datetime_utc: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False + ) + default_workspace: Mapped[bool] = mapped_column( + Boolean, + nullable=False, + server_default=text("false"), # Ensures existing rows default to false + ) + updated_datetime_utc: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False + ) + user: Mapped["UserDB"] = relationship("UserDB", back_populates="user_workspaces") + user_id: Mapped[int] = mapped_column( + Integer, ForeignKey("users.user_id", ondelete="CASCADE"), primary_key=True + ) + user_role: Mapped[UserRoles] = mapped_column( + Enum(UserRoles, native_enum=False), nullable=False + ) + workspace: Mapped["WorkspaceDB"] = relationship( + "WorkspaceDB", back_populates="user_workspaces" + ) + workspace_id: Mapped[int] = mapped_column( + Integer, + ForeignKey("workspace.workspace_id", ondelete="CASCADE"), + primary_key=True, + ) + + def __repr__(self) -> str: + """Define the string representation for the `UserWorkspaceDB` class.""" + return f"." 
async def check_if_user_has_default_workspace(
    *, asession: AsyncSession, user_db: "UserDB"
) -> bool | None:
    """Check if a user has an assigned default workspace.

    Returns the boolean result of an EXISTS query.
    """
    stmt = select(
        exists().where(
            UserWorkspaceDB.user_id == user_db.user_id,
            UserWorkspaceDB.default_workspace.is_(True),
        )
    )
    result = await asession.execute(stmt)
    return result.scalar()


async def get_user_default_workspace(
    *, asession: AsyncSession, user_db: "UserDB"
) -> WorkspaceDB:
    """Retrieve the default workspace for a given user.

    Raises `sqlalchemy.exc.NoResultFound` (via `scalar_one`) if the user has no
    default workspace.
    """
    stmt = (
        select(WorkspaceDB)
        .join(UserWorkspaceDB, UserWorkspaceDB.workspace_id == WorkspaceDB.workspace_id)
        .where(
            UserWorkspaceDB.user_id == user_db.user_id,
            UserWorkspaceDB.default_workspace.is_(True),
        )
        .limit(1)
    )

    result = await asession.execute(stmt)
    default_workspace_db = result.scalar_one()
    return default_workspace_db


async def get_user_workspaces(
    *, asession: AsyncSession, user_db: "UserDB"
) -> Sequence[WorkspaceDB]:
    """Retrieve all workspaces a user belongs to."""
    stmt = (
        select(WorkspaceDB)
        .join(UserWorkspaceDB, UserWorkspaceDB.workspace_id == WorkspaceDB.workspace_id)
        .where(UserWorkspaceDB.user_id == user_db.user_id)
    )
    result = await asession.execute(stmt)
    return result.scalars().all()


async def get_user_role_in_workspace(
    *, asession: AsyncSession, user_db: "UserDB", workspace_db: WorkspaceDB
) -> UserRoles | None:
    """Retrieve the role of a user in a workspace.

    Returns `None` when the user is not a member of the workspace.
    """
    stmt = select(UserWorkspaceDB.user_role).where(
        UserWorkspaceDB.user_id == user_db.user_id,
        UserWorkspaceDB.workspace_id == workspace_db.workspace_id,
    )
    result = await asession.execute(stmt)
    user_role = result.scalar_one_or_none()
    return user_role


async def update_user_default_workspace(
    *, asession: AsyncSession, user_db: "UserDB", workspace_db: WorkspaceDB
) -> None:
    """Update the default workspace for the user to the specified workspace.

    A single UPDATE flips `default_workspace` to True for the target workspace row
    and to False for all of the user's other workspace rows.
    """
    stmt = (
        update(UserWorkspaceDB)
        .where(UserWorkspaceDB.user_id == user_db.user_id)
        .values(
            default_workspace=case(
                (UserWorkspaceDB.workspace_id == workspace_db.workspace_id, True),
                else_=False,
            ),
            updated_datetime_utc=datetime.now(timezone.utc),
        )
    )

    await asession.execute(stmt)
    await asession.commit()


async def create_user_workspace_role(
    *,
    asession: AsyncSession,
    is_default_workspace: bool = False,
    user_db: "UserDB",
    user_role: UserRoles,
    workspace_db: WorkspaceDB,
) -> UserWorkspaceDB:
    """Create a user in a workspace with the specified role.

    Raises `UserWorkspaceRoleAlreadyExistsError` if the user already has any role
    in the workspace.
    """
    existing_user_role = await get_user_role_in_workspace(
        asession=asession, user_db=user_db, workspace_db=workspace_db
    )

    if existing_user_role is not None:
        raise UserWorkspaceRoleAlreadyExistsError(
            f"User '{user_db.username}' with role '{user_role}' in workspace "
            f"{workspace_db.workspace_name} already exists."
        )

    user_workspace_role_db = UserWorkspaceDB(
        created_datetime_utc=datetime.now(timezone.utc),
        default_workspace=is_default_workspace,
        updated_datetime_utc=datetime.now(timezone.utc),
        user_id=user_db.user_id,
        user_role=user_role,
        workspace_id=workspace_db.workspace_id,
    )

    asession.add(user_workspace_role_db)
    await asession.commit()
    await asession.refresh(user_workspace_role_db)

    return user_workspace_role_db


async def get_workspaces_by_user_role(
    *, asession: AsyncSession, user_db: "UserDB", user_role: UserRoles
) -> Sequence[WorkspaceDB]:
    """Retrieve all workspaces for the user with the specified role."""
    stmt = (
        select(WorkspaceDB)
        .join(UserWorkspaceDB, WorkspaceDB.workspace_id == UserWorkspaceDB.workspace_id)
        .where(UserWorkspaceDB.user_id == user_db.user_id)
        .where(UserWorkspaceDB.user_role == user_role)
    )
    result = await asession.execute(stmt)
    return result.scalars().all()


async def user_has_admin_role_in_any_workspace(
    *, asession: AsyncSession, user_db: "UserDB"
) -> bool:
    """Check if a user has the ADMIN role in any workspace."""
    stmt = (
        select(UserWorkspaceDB.user_id)
        .where(
            UserWorkspaceDB.user_id == user_db.user_id,
            UserWorkspaceDB.user_role == UserRoles.ADMIN,
        )
        .limit(1)
    )
    result = await asession.execute(stmt)
    return result.scalar_one_or_none() is not None


async def add_existing_user_to_workspace(
    *,
    asession: AsyncSession,
    user: UserCreate,
    workspace_db: WorkspaceDB,
) -> UserCreateWithCode:
    """Add an existing user to a workspace.

    Raises `ValueError` if `user.role` is missing, and propagates
    `UserNotFoundError` / `UserWorkspaceRoleAlreadyExistsError` from the helpers.
    """
    # Import here to avoid circular imports
    from ..users.models import get_user_by_username

    # BUG FIX: validate explicitly instead of `assert`, which is stripped under
    # `python -O` and would let a None role slip through.
    if user.role is None:
        raise ValueError("A role is required to add a user to a workspace.")

    # Use a local flag rather than mutating the caller's `UserCreate` object.
    is_default_workspace = bool(user.is_default_workspace)

    user_db = await get_user_by_username(username=user.username, asession=asession)

    if is_default_workspace:
        # Clear the default flag on the user's existing workspace rows first; the
        # row for this workspace does not exist yet and is inserted with the
        # default flag by `create_user_workspace_role` below.
        await update_user_default_workspace(
            asession=asession, user_db=user_db, workspace_db=workspace_db
        )

    await create_user_workspace_role(
        asession=asession,
        is_default_workspace=is_default_workspace,
        user_db=user_db,
        user_role=user.role,
        workspace_db=workspace_db,
    )

    return UserCreateWithCode(
        is_default_workspace=is_default_workspace,
        recovery_codes=[],  # We don't use recovery codes in our implementation
        role=user.role,
        username=user_db.username,
        workspace_name=workspace_db.workspace_name,
    )
from typing import Annotated, List

from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, status
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from redis.asyncio import Redis

from ..auth.dependencies import (
    create_access_token,
    get_current_user,
)
from ..auth.schemas import AuthenticationDetails
from ..config import DEFAULT_API_QUOTA, DEFAULT_EXPERIMENTS_QUOTA
from ..database import get_async_session, get_redis
from ..email import EmailService
from ..users.models import (
    UserDB,
    UserNotFoundError,
    save_user_to_db,
    get_user_by_username,
)
from ..users.schemas import UserCreate
from ..utils import generate_key, setup_logger
from .models import (
    add_existing_user_to_workspace,
    check_if_user_has_default_workspace,
    get_user_default_workspace,
    get_user_role_in_workspace,
    get_user_workspaces,
    get_workspaces_by_user_role,
    update_user_default_workspace,
    user_has_admin_role_in_any_workspace,
)
from .schemas import (
    UserRoles,
    WorkspaceCreate,
    WorkspaceInvite,
    WorkspaceInviteResponse,
    WorkspaceKeyResponse,
    WorkspaceRetrieve,
    WorkspaceSwitch,
    WorkspaceUpdate,
)
from .utils import (
    WorkspaceNotFoundError,
    create_workspace,
    get_workspace_by_workspace_id,
    get_workspace_by_workspace_name,
    is_workspace_name_valid,
    update_workspace_api_key,
    update_workspace_name_and_quotas,
)

TAG_METADATA = {
    "name": "Workspace",
    "description": "_Requires user login._ Only administrator user has access to these "
    "endpoints and only for the workspaces that they are assigned to.",
}

router = APIRouter(prefix="/workspace", tags=["Workspace"])
logger = setup_logger()
email_service = EmailService()


def _workspace_to_retrieve(workspace_db) -> WorkspaceRetrieve:
    """Map a `WorkspaceDB` ORM row to the `WorkspaceRetrieve` response schema."""
    return WorkspaceRetrieve(
        api_daily_quota=workspace_db.api_daily_quota,
        api_key_first_characters=workspace_db.api_key_first_characters,
        api_key_updated_datetime_utc=workspace_db.api_key_updated_datetime_utc,
        content_quota=workspace_db.content_quota,
        created_datetime_utc=workspace_db.created_datetime_utc,
        updated_datetime_utc=workspace_db.updated_datetime_utc,
        workspace_id=workspace_db.workspace_id,
        workspace_name=workspace_db.workspace_name,
        is_default=workspace_db.is_default,
    )


@router.post("/", response_model=WorkspaceRetrieve)
async def create_workspace_endpoint(
    calling_user_db: Annotated[UserDB, Depends(get_current_user)],
    workspace: WorkspaceCreate,
    asession: AsyncSession = Depends(get_async_session),
) -> WorkspaceRetrieve:
    """Create a new workspace. Workspaces can only be created by authenticated users.

    The calling user must already belong to a default workspace, and the requested
    workspace name must be unused. The caller is added as an ADMIN of the new
    workspace (but it is not made their default).
    """
    if not await check_if_user_has_default_workspace(
        asession=asession, user_db=calling_user_db
    ):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="User must be assigned to a workspace first before creating new workspaces.",
        )

    # Check if workspace name is valid
    if not await is_workspace_name_valid(
        asession=asession, workspace_name=workspace.workspace_name
    ):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Workspace with name '{workspace.workspace_name}' already exists.",
        )

    # Create new workspace
    api_key = generate_key()
    workspace_db, is_new_workspace = await create_workspace(
        api_daily_quota=workspace.api_daily_quota or DEFAULT_API_QUOTA,
        asession=asession,
        content_quota=workspace.content_quota or DEFAULT_EXPERIMENTS_QUOTA,
        user=UserCreate(
            role=UserRoles.ADMIN,
            username=calling_user_db.username,
            workspace_name=workspace.workspace_name,
        ),
        api_key=api_key,
    )

    if not is_new_workspace:
        # `create_workspace` found an existing row: someone created the workspace
        # between the name check above and the insert.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Workspace already exists.",
        )

    # Add the calling user as an admin to the new workspace
    await add_existing_user_to_workspace(
        asession=asession,
        user=UserCreate(
            is_default_workspace=False,  # Don't make it default automatically
            role=UserRoles.ADMIN,
            username=calling_user_db.username,
            workspace_name=workspace_db.workspace_name,
        ),
        workspace_db=workspace_db,
    )

    return _workspace_to_retrieve(workspace_db)


@router.get("/", response_model=List[WorkspaceRetrieve])
async def retrieve_all_workspaces(
    calling_user_db: Annotated[UserDB, Depends(get_current_user)],
    asession: AsyncSession = Depends(get_async_session),
) -> List[WorkspaceRetrieve]:
    """Return a list of all workspaces the user has access to."""
    user_workspaces = await get_user_workspaces(
        asession=asession, user_db=calling_user_db
    )
    return [_workspace_to_retrieve(workspace_db) for workspace_db in user_workspaces]


@router.get("/current", response_model=WorkspaceRetrieve)
async def get_current_workspace(
    calling_user_db: Annotated[UserDB, Depends(get_current_user)],
    asession: AsyncSession = Depends(get_async_session),
) -> WorkspaceRetrieve:
    """Return the current default workspace for the user."""
    try:
        workspace_db = await get_user_default_workspace(
            asession=asession, user_db=calling_user_db
        )
        return _workspace_to_retrieve(workspace_db)
    except Exception as e:
        # `get_user_default_workspace` raises NoResultFound when the user has no
        # default workspace; map any lookup failure to a 404.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="No default workspace found for the user.",
        ) from e
@router.post("/switch", response_model=AuthenticationDetails)
async def switch_workspace(
    calling_user_db: Annotated[UserDB, Depends(get_current_user)],
    workspace_switch: WorkspaceSwitch,
    asession: AsyncSession = Depends(get_async_session),
) -> AuthenticationDetails:
    """Switch to a different workspace.

    The target workspace becomes the user's default, and a fresh access token is
    issued.
    """
    # Find the workspace
    try:
        workspace_db = await get_workspace_by_workspace_name(
            asession=asession, workspace_name=workspace_switch.workspace_name
        )
    except WorkspaceNotFoundError as e:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Workspace '{workspace_switch.workspace_name}' not found.",
        ) from e

    # Check if user belongs to this workspace
    user_role = await get_user_role_in_workspace(
        asession=asession, user_db=calling_user_db, workspace_db=workspace_db
    )

    if user_role is None:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"User does not have access to workspace '{workspace_switch.workspace_name}'.",
        )

    # Set this workspace as the default for the user
    await update_user_default_workspace(
        asession=asession, user_db=calling_user_db, workspace_db=workspace_db
    )

    # Create a new token with the updated workspace information
    return AuthenticationDetails(
        access_level="fullaccess",
        access_token=create_access_token(calling_user_db.username),
        token_type="bearer",
        username=calling_user_db.username,
        is_verified=calling_user_db.is_verified,
        api_key_first_characters=calling_user_db.api_key_first_characters,
    )


@router.put("/rotate-key", response_model=WorkspaceKeyResponse)
async def rotate_workspace_api_key(
    calling_user_db: Annotated[UserDB, Depends(get_current_user)],
    asession: AsyncSession = Depends(get_async_session),
) -> WorkspaceKeyResponse:
    """Generate a new API key for the current workspace.

    Only workspace administrators may rotate the key. The plaintext key is
    returned once; only its hash is persisted.
    """
    try:
        # Get the user's default workspace
        workspace_db = await get_user_default_workspace(
            asession=asession, user_db=calling_user_db
        )

        # Verify user is an admin in this workspace
        user_role = await get_user_role_in_workspace(
            asession=asession, user_db=calling_user_db, workspace_db=workspace_db
        )

        if user_role != UserRoles.ADMIN:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Only workspace administrators can rotate API keys.",
            )

        # Generate and update the API key
        new_api_key = generate_key()
        asession.add(workspace_db)
        workspace_db = await update_workspace_api_key(
            asession=asession, new_api_key=new_api_key, workspace_db=workspace_db
        )

        return WorkspaceKeyResponse(
            new_api_key=new_api_key,
            workspace_name=workspace_db.workspace_name,
        )
    except HTTPException:
        # BUG FIX: without this re-raise, the generic handler below converted the
        # deliberate 403 above into a misleading 500.
        raise
    except Exception as e:
        logger.error(f"Error rotating workspace API key: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error rotating workspace API key.",
        ) from e


@router.get("/{workspace_id}", response_model=WorkspaceRetrieve)
async def retrieve_workspace_by_workspace_id(
    calling_user_db: Annotated[UserDB, Depends(get_current_user)],
    workspace_id: int,
    asession: AsyncSession = Depends(get_async_session),
) -> WorkspaceRetrieve:
    """Retrieve a workspace by ID. The caller must be a member of the workspace."""
    try:
        # Get the workspace
        workspace_db = await get_workspace_by_workspace_id(
            asession=asession, workspace_id=workspace_id
        )

        # Check if user has access to this workspace
        user_role = await get_user_role_in_workspace(
            asession=asession, user_db=calling_user_db, workspace_db=workspace_db
        )

        if user_role is None:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"User does not have access to workspace with ID {workspace_id}.",
            )

        return WorkspaceRetrieve(
            api_daily_quota=workspace_db.api_daily_quota,
            api_key_first_characters=workspace_db.api_key_first_characters,
            api_key_updated_datetime_utc=workspace_db.api_key_updated_datetime_utc,
            content_quota=workspace_db.content_quota,
            created_datetime_utc=workspace_db.created_datetime_utc,
            updated_datetime_utc=workspace_db.updated_datetime_utc,
            workspace_id=workspace_db.workspace_id,
            workspace_name=workspace_db.workspace_name,
            is_default=workspace_db.is_default,
        )
    except WorkspaceNotFoundError as e:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Workspace with ID {workspace_id} not found.",
        ) from e


@router.put("/{workspace_id}", response_model=WorkspaceRetrieve)
async def update_workspace_endpoint(
    workspace_id: int,
    workspace_update: WorkspaceUpdate,
    calling_user_db: Annotated[UserDB, Depends(get_current_user)],
    asession: AsyncSession = Depends(get_async_session),
) -> WorkspaceRetrieve:
    """Update workspace details (name, quotas). Admin-only."""
    try:
        # Get the workspace
        workspace_db = await get_workspace_by_workspace_id(
            asession=asession, workspace_id=workspace_id
        )

        # Verify user is an admin in this workspace
        user_role = await get_user_role_in_workspace(
            asession=asession, user_db=calling_user_db, workspace_db=workspace_db
        )

        if user_role != UserRoles.ADMIN:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Only workspace administrators can update workspace details.",
            )

        # Check if the new workspace name is valid (only when it actually changes)
        if (
            workspace_update.workspace_name
            and workspace_update.workspace_name != workspace_db.workspace_name
        ):
            if not await is_workspace_name_valid(
                asession=asession, workspace_name=workspace_update.workspace_name
            ):
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"Workspace with name '{workspace_update.workspace_name}' already exists.",
                )

        # Update the workspace
        asession.add(workspace_db)
        updated_workspace = await update_workspace_name_and_quotas(
            asession=asession, workspace=workspace_update, workspace_db=workspace_db
        )

        return WorkspaceRetrieve(
            api_daily_quota=updated_workspace.api_daily_quota,
            api_key_first_characters=updated_workspace.api_key_first_characters,
            api_key_updated_datetime_utc=updated_workspace.api_key_updated_datetime_utc,
            content_quota=updated_workspace.content_quota,
            created_datetime_utc=updated_workspace.created_datetime_utc,
            updated_datetime_utc=updated_workspace.updated_datetime_utc,
            workspace_id=updated_workspace.workspace_id,
            workspace_name=updated_workspace.workspace_name,
            is_default=updated_workspace.is_default,
        )
    except WorkspaceNotFoundError as e:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Workspace with ID {workspace_id} not found.",
        ) from e
    except SQLAlchemyError as e:
        logger.error(f"Database error when updating workspace: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Database error when updating workspace.",
        ) from e


@router.post("/invite", response_model=WorkspaceInviteResponse)
async def invite_user_to_workspace(
    calling_user_db: Annotated[UserDB, Depends(get_current_user)],
    invite: WorkspaceInvite,
    background_tasks: BackgroundTasks,
    asession: AsyncSession = Depends(get_async_session),
    redis: Redis = Depends(get_redis),
) -> WorkspaceInviteResponse:
    """Invite a user to join a workspace.

    Existing users are added immediately; unknown email addresses receive an
    invitation to create an account. Emails are sent via background tasks.
    """
    try:
        # Get the workspace
        workspace_db = await get_workspace_by_workspace_name(
            asession=asession, workspace_name=invite.workspace_name
        )

        # Users can't invite others to default workspaces
        if workspace_db.is_default:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Users cannot be invited to default workspaces.",
            )

        # Verify caller is an admin in this workspace
        user_role = await get_user_role_in_workspace(
            asession=asession, user_db=calling_user_db, workspace_db=workspace_db
        )

        if user_role != UserRoles.ADMIN:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Only workspace administrators can invite users.",
            )

        # Check if the invited user exists; the lookup raising UserNotFoundError
        # selects the "send sign-up invitation" branch below.
        try:
            await get_user_by_username(username=invite.email, asession=asession)

            # Add existing user to workspace
            await add_existing_user_to_workspace(
                asession=asession,
                user=UserCreate(
                    role=invite.role,
                    username=invite.email,
                    workspace_name=invite.workspace_name,
                ),
                workspace_db=workspace_db,
            )

            # Send invitation email to existing user
            background_tasks.add_task(
                email_service.send_workspace_invitation_email,
                invite.email,
                invite.email,
                calling_user_db.username,
                workspace_db.workspace_name,
                True,  # user exists
            )

            return WorkspaceInviteResponse(
                message=f"User {invite.email} has been added to workspace '{workspace_db.workspace_name}'.",
                email=invite.email,
                workspace_name=workspace_db.workspace_name,
                user_exists=True,
            )

        except UserNotFoundError:
            # User doesn't exist, send invitation to create account
            background_tasks.add_task(
                email_service.send_workspace_invitation_email,
                invite.email,
                invite.email,
                calling_user_db.username,
                workspace_db.workspace_name,
                False,  # user doesn't exist
            )

            return WorkspaceInviteResponse(
                message=f"Invitation sent to {invite.email} to join workspace '{workspace_db.workspace_name}'.",
                email=invite.email,
                workspace_name=workspace_db.workspace_name,
                user_exists=False,
            )

    except WorkspaceNotFoundError as e:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Workspace '{invite.workspace_name}' not found.",
        ) from e
    except HTTPException:
        # BUG FIX: without this re-raise, the generic handler below converted the
        # deliberate 403s above into misleading 500s.
        raise
    except Exception as e:
        logger.error(f"Error inviting user to workspace: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error inviting user to workspace.",
        ) from e
from datetime import datetime
from enum import Enum

from pydantic import BaseModel, ConfigDict, EmailStr


class UserRoles(str, Enum):
    """Enumeration for user roles.

    There are 2 different types of users:

    1. (Read-Only) Users: These users are assigned to workspaces and can only read the
        contents within their assigned workspaces. They cannot modify existing
        contents or add new contents to their workspaces, add or delete users from
        their workspaces, or add or delete workspaces.
    2. Admin Users: These users are assigned to workspaces and can read and modify the
        contents within their assigned workspaces. They can also add or delete users
        from their own workspaces and can also add new workspaces or delete their own
        workspaces. Admin users have no control over workspaces that they are not
        assigned to.
    """

    ADMIN = "admin"
    READ_ONLY = "read_only"


class WorkspaceCreate(BaseModel):
    """Pydantic model for workspace creation."""

    # Quotas are optional; the router substitutes configured defaults when absent.
    api_daily_quota: int | None = None
    content_quota: int | None = None
    workspace_name: str

    model_config = ConfigDict(from_attributes=True)


class WorkspaceKeyResponse(BaseModel):
    """Pydantic model for updating workspace API key."""

    new_api_key: str
    workspace_name: str

    model_config = ConfigDict(from_attributes=True)


class WorkspaceRetrieve(BaseModel):
    """Pydantic model for workspace retrieval."""

    api_daily_quota: int | None = None
    api_key_first_characters: str | None = None
    api_key_updated_datetime_utc: datetime | None = None
    content_quota: int | None = None
    created_datetime_utc: datetime
    updated_datetime_utc: datetime | None = None
    workspace_id: int
    workspace_name: str
    is_default: bool = False

    model_config = ConfigDict(from_attributes=True)


class WorkspaceSwitch(BaseModel):
    """Pydantic model for switching workspaces."""

    workspace_name: str

    model_config = ConfigDict(from_attributes=True)


class WorkspaceUpdate(BaseModel):
    """Pydantic model for workspace updates."""

    workspace_name: str | None = None

    model_config = ConfigDict(from_attributes=True)


class UserWorkspace(BaseModel):
    """Pydantic model for user workspace information."""

    user_role: UserRoles
    workspace_id: int
    workspace_name: str
    is_default: bool = False

    model_config = ConfigDict(from_attributes=True)


class UserCreateWithCode(BaseModel):
    """Pydantic model for user creation with recovery codes."""

    is_default_workspace: bool = False
    # Pydantic v2 deep-copies mutable defaults per instance, so `= []` is safe.
    recovery_codes: list[str] = []
    role: UserRoles
    username: str
    workspace_name: str

    model_config = ConfigDict(from_attributes=True)


class WorkspaceInvite(BaseModel):
    """Pydantic model for inviting users to a workspace."""

    email: EmailStr
    role: UserRoles = UserRoles.READ_ONLY
    workspace_name: str

    model_config = ConfigDict(from_attributes=True)


class WorkspaceInviteResponse(BaseModel):
    """Pydantic model for workspace invite response."""

    message: str
    email: EmailStr
    workspace_name: str
    user_exists: bool

    model_config = ConfigDict(from_attributes=True)
async def check_if_workspaces_exist(*, asession: AsyncSession) -> bool:
    """Check if any workspaces exist in the `WorkspaceDB` database."""
    stmt = select(WorkspaceDB.workspace_id).limit(1)
    result = await asession.scalars(stmt)
    return result.first() is not None


async def create_workspace(
    *,
    api_daily_quota: Optional[int] = None,
    asession: AsyncSession,
    content_quota: Optional[int] = None,
    user: UserCreate,
    is_default: bool = False,
    api_key: Optional[str] = None,
) -> tuple[WorkspaceDB, bool]:
    """Create a workspace in the `WorkspaceDB` database. If the workspace already
    exists, then it is returned unchanged.

    Returns a tuple of (workspace row, created-now flag). Note that `api_key` is
    only applied when the workspace is newly created; an existing workspace's key
    is left untouched.
    """
    if api_daily_quota is not None and api_daily_quota < 0:
        raise ValueError("api_daily_quota must be non-negative.")
    if content_quota is not None and content_quota < 0:
        raise ValueError("content_quota must be non-negative.")

    result = await asession.execute(
        select(WorkspaceDB).where(WorkspaceDB.workspace_name == user.workspace_name)
    )
    workspace_db = result.scalar_one_or_none()
    new_workspace = False

    if workspace_db is None:
        new_workspace = True
        workspace_db = WorkspaceDB(
            api_daily_quota=api_daily_quota,
            content_quota=content_quota,
            created_datetime_utc=datetime.now(timezone.utc),
            updated_datetime_utc=datetime.now(timezone.utc),
            workspace_name=user.workspace_name,
            is_default=is_default,
        )

        if api_key:
            # Only the hash is persisted; the first characters are kept for display.
            workspace_db.hashed_api_key = get_key_hash(api_key)
            workspace_db.api_key_first_characters = api_key[:5]
            workspace_db.api_key_updated_datetime_utc = datetime.now(timezone.utc)

        asession.add(workspace_db)
        await asession.commit()
        await asession.refresh(workspace_db)

    return workspace_db, new_workspace


async def get_workspace_by_workspace_id(
    *, asession: AsyncSession, workspace_id: int
) -> WorkspaceDB:
    """Retrieve a workspace by workspace ID.

    Raises `WorkspaceNotFoundError` if no workspace has the given ID.
    """
    stmt = select(WorkspaceDB).where(WorkspaceDB.workspace_id == workspace_id)
    result = await asession.execute(stmt)
    try:
        return result.scalar_one()
    except NoResultFound as err:
        raise WorkspaceNotFoundError(
            f"Workspace with ID {workspace_id} does not exist."
        ) from err


async def get_workspace_by_workspace_name(
    *, asession: AsyncSession, workspace_name: str
) -> WorkspaceDB:
    """Retrieve a workspace by workspace name.

    Raises `WorkspaceNotFoundError` if no workspace has the given name.
    """
    stmt = select(WorkspaceDB).where(WorkspaceDB.workspace_name == workspace_name)
    result = await asession.execute(stmt)
    try:
        return result.scalar_one()
    except NoResultFound as err:
        raise WorkspaceNotFoundError(
            f"Workspace with name {workspace_name} does not exist."
        ) from err


async def is_workspace_name_valid(
    *, asession: AsyncSession, workspace_name: str
) -> bool:
    """Check if a workspace name is valid. A workspace name is valid if it doesn't
    already exist in the database.
    """
    # `workspace_name` is unique, so at most one row can match; using
    # scalar_one_or_none avoids exception-driven control flow.
    stmt = select(WorkspaceDB.workspace_id).where(
        WorkspaceDB.workspace_name == workspace_name
    )
    result = await asession.execute(stmt)
    return result.scalar_one_or_none() is None


async def update_workspace_api_key(
    *, asession: AsyncSession, new_api_key: str, workspace_db: WorkspaceDB
) -> WorkspaceDB:
    """Update a workspace API key (stores the hash, never the plaintext)."""
    workspace_db.hashed_api_key = get_key_hash(key=new_api_key)
    workspace_db.api_key_first_characters = new_api_key[:5]
    workspace_db.api_key_updated_datetime_utc = datetime.now(timezone.utc)
    workspace_db.updated_datetime_utc = datetime.now(timezone.utc)

    await asession.commit()
    await asession.refresh(workspace_db)

    return workspace_db


async def update_workspace_name_and_quotas(
    *, asession: AsyncSession, workspace: WorkspaceUpdate, workspace_db: WorkspaceDB
) -> WorkspaceDB:
    """Update the workspace name.

    NOTE(review): despite the function name, `WorkspaceUpdate` currently carries
    only `workspace_name`, so no quota fields are updated here.
    """
    if workspace.workspace_name is not None:
        workspace_db.workspace_name = workspace.workspace_name

    workspace_db.updated_datetime_utc = datetime.now(timezone.utc)

    await asession.commit()
    await asession.refresh(workspace_db)

    return workspace_db
"""Workspace relationship

Revision ID: 949c9fc0461d
Revises: 977e7e73ce06
Create Date: 2025-04-21 21:20:56.282928

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = '949c9fc0461d'
down_revision: Union[str, None] = '977e7e73ce06'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None

# BUG FIX: the autogenerated script created the foreign key with name=None and
# then called `op.drop_constraint(None, ...)` in downgrade(), which fails because
# a constraint cannot be dropped by a None name. Naming the constraint explicitly
# makes the migration reversible.
FK_EXPERIMENTS_WORKSPACE = "fk_experiments_base_workspace_id_workspace"


def upgrade() -> None:
    """Add `experiments_base.workspace_id` with a named FK to `workspace`."""
    # NOTE(review): adding a non-nullable column without a server_default fails
    # if `experiments_base` already contains rows — confirm the table is empty
    # at deploy time or add a backfill step.
    op.add_column(
        'experiments_base', sa.Column('workspace_id', sa.Integer(), nullable=False)
    )
    op.create_foreign_key(
        FK_EXPERIMENTS_WORKSPACE,
        'experiments_base',
        'workspace',
        ['workspace_id'],
        ['workspace_id'],
    )


def downgrade() -> None:
    """Drop the FK (by its explicit name) and the `workspace_id` column."""
    op.drop_constraint(
        FK_EXPERIMENTS_WORKSPACE, 'experiments_base', type_='foreignkey'
    )
    op.drop_column('experiments_base', 'workspace_id')
# revision identifiers, used by Alembic.
revision: str = '977e7e73ce06'
down_revision: Union[str, None] = 'ba1bf29910f5'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Create the `workspace` table and the `user_workspace` association table."""
    # ### commands auto generated by Alembic - please adjust! ###
    # Workspace rows; name and hashed API key are unique across all workspaces.
    op.create_table('workspace',
    sa.Column('api_daily_quota', sa.Integer(), nullable=True),
    sa.Column('api_key_first_characters', sa.String(length=5), nullable=True),
    sa.Column('api_key_updated_datetime_utc', sa.DateTime(timezone=True), nullable=True),
    sa.Column('content_quota', sa.Integer(), nullable=True),
    sa.Column('created_datetime_utc', sa.DateTime(timezone=True), nullable=False),
    sa.Column('hashed_api_key', sa.String(length=96), nullable=True),
    sa.Column('updated_datetime_utc', sa.DateTime(timezone=True), nullable=False),
    sa.Column('workspace_id', sa.Integer(), nullable=False),
    sa.Column('workspace_name', sa.String(), nullable=False),
    sa.Column('is_default', sa.Boolean(), nullable=False),
    sa.PrimaryKeyConstraint('workspace_id'),
    sa.UniqueConstraint('hashed_api_key'),
    sa.UniqueConstraint('workspace_name')
    )
    # MANY-to-MANY users<->workspaces; the composite PK (user_id, workspace_id)
    # permits at most one role per user per workspace, and CASCADE deletes clean
    # up memberships when either side is removed.
    op.create_table('user_workspace',
    sa.Column('created_datetime_utc', sa.DateTime(timezone=True), nullable=False),
    sa.Column('default_workspace', sa.Boolean(), server_default=sa.text('false'), nullable=False),
    sa.Column('updated_datetime_utc', sa.DateTime(timezone=True), nullable=False),
    sa.Column('user_id', sa.Integer(), nullable=False),
    sa.Column('user_role', sa.Enum('ADMIN', 'READ_ONLY', name='userroles', native_enum=False), nullable=False),
    sa.Column('workspace_id', sa.Integer(), nullable=False),
    sa.ForeignKeyConstraint(['user_id'], ['users.user_id'], ondelete='CASCADE'),
    sa.ForeignKeyConstraint(['workspace_id'], ['workspace.workspace_id'], ondelete='CASCADE'),
    sa.PrimaryKeyConstraint('user_id', 'workspace_id')
    )
    # ### end Alembic commands ###


def downgrade() -> None:
    """Drop both tables; the association table first to satisfy FK ordering."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_table('user_workspace')
    op.drop_table('workspace')
    # ### end Alembic commands ###
"Success", + description: `Workspace "${response.workspace_name}" created and activated!`, + }); + + router.push("/workspace"); + } catch (error: any) { + toast({ + title: "Error", + description: error.message || "Failed to create workspace", + variant: "destructive", + }); + } finally { + setIsSubmitting(false); + } + }; + + return ( +
+ + +
+ +
+ Create new workspace + + Create a new workspace to organize your experiments and team members + +
+
+
+ +
+
+ + + + + {form.formState.errors.workspace_name && ( +

+ {form.formState.errors.workspace_name.message} +

+ )} + Choose a descriptive name for your new workspace. +
+
+ +
+ + +
+
+
+
+
+
+ ); +} diff --git a/frontend/src/app/(protected)/workspace/invite/page.tsx b/frontend/src/app/(protected)/workspace/invite/page.tsx new file mode 100644 index 0000000..f610ffc --- /dev/null +++ b/frontend/src/app/(protected)/workspace/invite/page.tsx @@ -0,0 +1,193 @@ +"use client"; + +import { useState } from "react"; +import { useAuth } from "@/utils/auth"; +import { apiCalls } from "@/utils/api"; +import { useToast } from "@/hooks/use-toast"; +import { z } from "zod"; +import { zodResolver } from "@hookform/resolvers/zod"; +import { useForm } from "react-hook-form"; +import { Button } from "@/components/catalyst/button"; +import { Heading } from "@/components/catalyst/heading"; +import { Input } from "@/components/catalyst/input"; +import { + Card, + CardContent, + CardDescription, + CardFooter, + CardHeader, + CardTitle, +} from "@/components/ui/card"; +import { + Fieldset, + Field, + FieldGroup, + Label, + Description, +} from "@/components/catalyst/fieldset"; +import { Radio, RadioField, RadioGroup } from "@/components/catalyst/radio"; +import { Badge } from "@/components/ui/badge"; +import { EnvelopeIcon, UserPlusIcon } from "@heroicons/react/20/solid"; + +const inviteSchema = z.object({ + email: z.string().email({ + message: "Please enter a valid email address", + }), + role: z.enum(["ADMIN", "EDITOR", "VIEWER"], { + required_error: "Please select a role", + }), +}); + +type InviteFormValues = z.infer; + +export default function InviteUsersPage() { + const { token, currentWorkspace } = useAuth(); + const { toast } = useToast(); + const [isSubmitting, setIsSubmitting] = useState(false); + const [invitedUsers, setInvitedUsers] = useState<{ email: string; role: string; exists: boolean }[]>([]); + + const { + register, + handleSubmit, + reset, + formState: { errors }, + setValue, + watch, + } = useForm({ + resolver: zodResolver(inviteSchema), + defaultValues: { + email: "", + role: "VIEWER", + }, + }); + + const roleValue = watch("role"); + + const onSubmit = 
async (data: InviteFormValues) => { + if (!token || !currentWorkspace) { + toast({ + title: "Error", + description: "You must be logged in and have a workspace selected", + variant: "destructive", + }); + return; + } + + setIsSubmitting(true); + try { + const response = await apiCalls.inviteUserToWorkspace(token, { + email: data.email, + role: data.role, + workspace_name: currentWorkspace.workspace_name, + }); + + // Add to invited users list + setInvitedUsers([ + ...invitedUsers, + { + email: data.email, + role: data.role, + exists: response.user_exists, + }, + ]); + + // Reset form + reset(); + + toast({ + title: "Success", + description: `Invitation sent to ${data.email}`, + }); + } catch (error: any) { + toast({ + title: "Error", + description: error.message || "Failed to send invitation", + variant: "destructive", + }); + } finally { + setIsSubmitting(false); + } + }; + + return ( +
+
+ Invite Team Members +
+ +
+
+ + + Send Invitation + + Invite users to join your workspace: {currentWorkspace?.workspace_name} + + + +
+
+ + + + + {errors.email && ( +

+ {errors.email.message} +

+ )} + Enter the email address of the person you want to invite. +
+
+ + + + setValue("role", value as "ADMIN" | "VIEWER")} + > + + + + + Can manage workspace settings, invite members, and has full access to all resources. + + + + + + + Can only view resources, but cannot edit them or change any settings. + + + + {errors.role && ( +

+ {errors.role.message} +

+ )} +
+
+ +
+ +
+
+
+
+
+
+
+ ); +} diff --git a/frontend/src/app/(protected)/workspace/page.tsx b/frontend/src/app/(protected)/workspace/page.tsx new file mode 100644 index 0000000..daceb87 --- /dev/null +++ b/frontend/src/app/(protected)/workspace/page.tsx @@ -0,0 +1,116 @@ +"use client"; + +import { useEffect } from "react"; +import { useRouter } from "next/navigation"; +import { useAuth } from "@/utils/auth"; +import { Heading } from "@/components/catalyst/heading"; +import { Card, CardContent, CardHeader, CardTitle, CardDescription } from "@/components/ui/card"; +import { Plus, Users, Settings, Key } from "lucide-react"; +import { Button } from "@/components/catalyst/button"; + +export default function WorkspacePage() { + const { currentWorkspace } = useAuth(); + const router = useRouter(); + + if (!currentWorkspace) { + return ( +
+ + + + Something went wrong. Please try again later. + + + +
+ ); + } + console.log("Current Workspace:", currentWorkspace); + + return ( +
+
+ {currentWorkspace.workspace_name} +
+ +
+ + + Workspace Information + + +
+
+
Name:
+
{currentWorkspace.workspace_name}
+
+
+
API Quota:
+
{currentWorkspace.api_daily_quota.toLocaleString()} calls/day
+
+
+
Experiment Quota:
+
{currentWorkspace.content_quota.toLocaleString()} experiments
+
+
+
Created:
+
+ {new Date(currentWorkspace.created_datetime_utc).toLocaleDateString()} +
+
+
+
+
+ + + + API Key + + +
+
+ + + {currentWorkspace.api_key_first_characters} + {"•".repeat(27)} + +
+ +
+

+ Use this API key to authenticate your API requests. Keep it secret and secure. +

+
+
+
+ +
+ router.push('/workspace/invite')}> + +
+ Invite Team Members + + Invite colleagues to join your workspace + +
+ +
+
+ + router.push('/workspace/create')}> + +
+ Create New Workspace + + Create a new workspace for different projects + +
+ +
+
+
+
+ ); +} \ No newline at end of file diff --git a/frontend/src/app/(protected)/workspace/types.ts b/frontend/src/app/(protected)/workspace/types.ts new file mode 100644 index 0000000..65c4108 --- /dev/null +++ b/frontend/src/app/(protected)/workspace/types.ts @@ -0,0 +1,51 @@ +export enum UserRoles { + ADMIN = "ADMIN", + EDITOR = "EDITOR", + VIEWER = "VIEWER", +} + +export interface Workspace { + workspace_id: number; + workspace_name: string; + api_key_first_characters: string; + api_key_updated_datetime_utc: string; + api_daily_quota: number; + content_quota: number; + created_datetime_utc: string; + updated_datetime_utc: string; + is_default: boolean; +} + +export interface WorkspaceCreate { + workspace_name: string; + api_daily_quota?: number; + content_quota?: number; +} + +export interface WorkspaceUpdate { + workspace_name?: string; + api_daily_quota?: number; + content_quota?: number; +} + +export interface WorkspaceKeyResponse { + new_api_key: string; + workspace_name: string; +} + +export interface WorkspaceInvite { + email: string; + role: UserRoles; + workspace_name: string; +} + +export interface WorkspaceInviteResponse { + message: string; + email: string; + workspace_name: string; + user_exists: boolean; +} + +export interface WorkspaceSwitch { + workspace_name: string; +} diff --git a/frontend/src/components/WorkspaceSelector.tsx b/frontend/src/components/WorkspaceSelector.tsx new file mode 100644 index 0000000..c958ea1 --- /dev/null +++ b/frontend/src/components/WorkspaceSelector.tsx @@ -0,0 +1,142 @@ +"use client"; + +import React, { useState } from "react"; +import { useAuth } from "@/utils/auth"; +import { Button } from "@/components/catalyst/button"; +import { + Dialog, + DialogActions, + DialogBody, + DialogDescription, + DialogTitle, +} from "@/components/catalyst/dialog"; +import { BuildingOfficeIcon, ChevronUpDownIcon, PlusIcon } from "@heroicons/react/20/solid"; +import { + DropdownItem, + DropdownLabel, + DropdownMenu, + DropdownButton, + 
DropdownDivider, + Dropdown, +} from "@/components/catalyst/dropdown"; +import { useToast } from "@/hooks/use-toast"; +import { useRouter } from "next/navigation"; + +export default function WorkspaceSelector() { + const { currentWorkspace, workspaces, switchWorkspace, isLoading } = useAuth(); + const [isOpen, setIsOpen] = useState(false); + const { toast } = useToast(); + const router = useRouter(); + + const handleSwitchWorkspace = async (workspaceName: string) => { + try { + await switchWorkspace(workspaceName); + toast({ + title: "Workspace Changed", + description: `Switched to workspace: ${workspaceName}`, + }); + } catch (error) { + toast({ + title: "Error", + description: "Failed to switch workspace", + variant: "destructive", + }); + } + }; + + const handleCreateWorkspace = () => { + router.push("/workspace/create"); + }; + + if (isLoading || !currentWorkspace) { + return ( + + ); + } + + return ( +
+ + + + + {currentWorkspace.workspace_name} + + + + + + {workspaces.map((workspace) => ( + handleSwitchWorkspace(workspace.workspace_name)} + > + + {workspace.workspace_name} + + ))} + + + + + + Create New Workspace + + + + + setIsOpen(false)}> + Switch Workspace + + Select a workspace to switch to + + +
+ {workspaces.map((workspace) => ( +
{ + handleSwitchWorkspace(workspace.workspace_name); + setIsOpen(false); + }} + > +
{workspace.workspace_name}
+ {workspace.workspace_id === currentWorkspace.workspace_id && ( +
Current
+ )} +
+ ))} +
+
+ + + + +
+
+ ); +} \ No newline at end of file diff --git a/frontend/src/components/sidebar.tsx b/frontend/src/components/sidebar.tsx index fe2d64f..482b0a6 100644 --- a/frontend/src/components/sidebar.tsx +++ b/frontend/src/components/sidebar.tsx @@ -10,6 +10,7 @@ import { SidebarLabel, SidebarSection, SidebarSpacer, + SidebarDivider, } from "@/components/catalyst/sidebar"; import { Avatar } from "@/components/catalyst/avatar"; import { Dropdown, DropdownButton } from "@/components/catalyst/dropdown"; @@ -22,11 +23,14 @@ import { MagnifyingGlassIcon, QuestionMarkCircleIcon, SparklesIcon, + BuildingOfficeIcon, + UserPlusIcon, } from "@heroicons/react/20/solid"; import { useAuth } from "@/utils/auth"; +import WorkspaceSelector from "./WorkspaceSelector"; export const SidebarComponent = (): React.ReactNode => { - const { user } = useAuth(); + const { user, currentWorkspace } = useAuth(); return ( @@ -43,6 +47,19 @@ export const SidebarComponent = (): React.ReactNode => { + + Workspace + + + + Manage Workspace + + + + Invite Members + + + {navItems.map((item) => ( @@ -51,7 +68,7 @@ export const SidebarComponent = (): React.ReactNode => { ))} - + {/* New experiments Modifying voice of chatbot @@ -65,7 +82,7 @@ export const SidebarComponent = (): React.ReactNode => { When to send the message - + */} diff --git a/frontend/src/utils/api.ts b/frontend/src/utils/api.ts index 4c8d0df..23ca704 100644 --- a/frontend/src/utils/api.ts +++ b/frontend/src/utils/api.ts @@ -90,42 +90,97 @@ const registerUser = async (username: string, password: string) => { } }; -const requestPasswordReset = async (username: string) => { +const getCurrentWorkspace = async (token: string | null) => { try { - const response = await api.post("/request-password-reset", { username }); + const response = await api.get("/workspace/current", { + headers: { + Authorization: `Bearer ${token}`, + }, + }); return response.data; } catch (error) { - throw new Error("Error requesting password reset"); + throw new 
Error("Error fetching current workspace"); } }; -const resetPassword = async (token: string, newPassword: string) => { +const getAllWorkspaces = async (token: string | null) => { try { - const response = await api.post("/reset-password", { - token, - new_password: newPassword + const response = await api.get("/workspace/", { + headers: { + Authorization: `Bearer ${token}`, + }, }); return response.data; } catch (error) { - throw new Error("Error resetting password"); + throw new Error("Error fetching workspaces"); } }; -const verifyEmail = async (token: string) => { +const createWorkspace = async (token: string | null, workspaceData: any) => { try { - const response = await api.post("/verify-email", { token }); + const response = await api.post("/workspace/", workspaceData, { + headers: { + Authorization: `Bearer ${token}`, + }, + }); return response.data; } catch (error) { - throw new Error("Error verifying email"); + throw new Error("Error creating workspace"); } }; -const resendVerification = async (username: string) => { +const updateWorkspace = async (token: string | null, workspaceId: number, workspaceData: any) => { try { - const response = await api.post("/resend-verification", { username }); + const response = await api.put(`/workspace/${workspaceId}`, workspaceData, { + headers: { + Authorization: `Bearer ${token}`, + }, + }); return response.data; } catch (error) { - throw new Error("Error resending verification email"); + throw new Error("Error updating workspace"); + } +}; + +const switchWorkspace = async (token: string | null, workspaceName: string) => { + try { + const response = await api.post("/workspace/switch", + { workspace_name: workspaceName }, + { + headers: { + Authorization: `Bearer ${token}`, + }, + } + ); + return response.data; + } catch (error) { + throw new Error("Error switching workspace"); + } +}; + +const rotateWorkspaceApiKey = async (token: string | null) => { + try { + const response = await api.put("/workspace/rotate-key", {}, 
{ + headers: { + Authorization: `Bearer ${token}`, + }, + }); + return response.data; + } catch (error) { + throw new Error("Error rotating workspace API key"); + } +}; + +const inviteUserToWorkspace = async (token: string | null, inviteData: any) => { + try { + const response = await api.post("/workspace/invite", inviteData, { + headers: { + Authorization: `Bearer ${token}`, + }, + }); + return response.data; + } catch (error) { + throw new Error("Error inviting user to workspace"); } }; @@ -177,5 +232,12 @@ export const apiCalls = { resetPassword, verifyEmail, resendVerification, + getCurrentWorkspace, + getAllWorkspaces, + createWorkspace, + updateWorkspace, + switchWorkspace, + rotateWorkspaceApiKey, + inviteUserToWorkspace, }; export default api; diff --git a/frontend/src/utils/auth.tsx b/frontend/src/utils/auth.tsx index ee07bad..814cf98 100644 --- a/frontend/src/utils/auth.tsx +++ b/frontend/src/utils/auth.tsx @@ -3,14 +3,24 @@ import { apiCalls } from "@/utils/api"; import { useRouter, useSearchParams } from "next/navigation"; import { ReactNode, createContext, useContext, useState, useEffect } from "react"; +type Workspace = { + workspace_id: number; + workspace_name: string; + api_key_first_characters: string; + is_default: boolean; +}; + type AuthContextType = { token: string | null; user: string | null; isVerified: boolean; isLoading: boolean; - login: (username: string, password: string) => void; + currentWorkspace: Workspace | null; + workspaces: Workspace[]; + login: (username: string, password: string) => Promise; logout: () => void; loginError: string | null; + switchWorkspace: (workspaceName: string) => Promise; loginGoogle: ({ client_id, credential, @@ -46,10 +56,35 @@ const AuthProvider = ({ children }: AuthProviderProps) => { const [isVerified, setIsVerified] = useState(false); const [isLoading, setIsLoading] = useState(!!getInitialToken()); const [loginError, setLoginError] = useState(null); + const [currentWorkspace, setCurrentWorkspace] = 
useState(null); + const [workspaces, setWorkspaces] = useState([]); const searchParams = useSearchParams(); const router = useRouter(); + useEffect(() => { + const loadWorkspaceInfo = async () => { + if (token) { + try { + setIsLoading(true); + // Fetch current workspace + const currentWorkspaceData = await apiCalls.getCurrentWorkspace(token); + setCurrentWorkspace(currentWorkspaceData); + + // Fetch all workspaces + const workspacesData = await apiCalls.getAllWorkspaces(token); + setWorkspaces(workspacesData); + } catch (error) { + console.error("Error loading workspace info:", error); + } finally { + setIsLoading(false); + } + } + }; + + loadWorkspaceInfo(); + }, [token]); + // Check verification status on init if token exists useEffect(() => { const checkVerificationStatus = async () => { @@ -101,6 +136,16 @@ const AuthProvider = ({ children }: AuthProviderProps) => { } } + try { + const currentWorkspaceData = await apiCalls.getCurrentWorkspace(access_token); + setCurrentWorkspace(currentWorkspaceData); + + const workspacesData = await apiCalls.getAllWorkspaces(access_token); + setWorkspaces(workspacesData); + } catch (error) { + console.error("Error loading workspace info:", error); + } + // Redirect to verification page if not verified, otherwise to original destination if (response.is_verified === false) { router.push("/verification-required"); @@ -123,6 +168,26 @@ const AuthProvider = ({ children }: AuthProviderProps) => { } }; + const switchWorkspace = async (workspaceName: string) => { + try { + setIsLoading(true); + const response = await apiCalls.switchWorkspace(token, workspaceName); + + localStorage.setItem("ee-token", response.access_token); + setToken(response.access_token); + + const currentWorkspaceData = await apiCalls.getCurrentWorkspace(response.access_token); + setCurrentWorkspace(currentWorkspaceData); + + return response; + } catch (error) { + console.error("Error switching workspace:", error); + throw error; + } finally { + 
setIsLoading(false); + } + }; + const loginGoogle = async ({ client_id, credential, @@ -151,6 +216,16 @@ const AuthProvider = ({ children }: AuthProviderProps) => { setIsVerified(true); + try { + const currentWorkspaceData = await apiCalls.getCurrentWorkspace(access_token); + setCurrentWorkspace(currentWorkspaceData); + + const workspacesData = await apiCalls.getAllWorkspaces(access_token); + setWorkspaces(workspacesData); + } catch (error) { + console.error("Error loading workspace info:", error); + } + router.push(sourcePage); } catch (error) { setLoginError("Invalid Google credentials"); @@ -167,6 +242,8 @@ const AuthProvider = ({ children }: AuthProviderProps) => { setUser(null); setToken(null); setIsVerified(false); + setCurrentWorkspace(null); + setWorkspaces([]); router.push("/login"); }; @@ -175,9 +252,12 @@ const AuthProvider = ({ children }: AuthProviderProps) => { user, isVerified, isLoading, + currentWorkspace, + workspaces, login, loginError, loginGoogle, + switchWorkspace, logout, }; From 219aea776ed42d8e1c20422c82690e6d26e4ad52 Mon Sep 17 00:00:00 2001 From: Jay Prakash <0freerunning@gmail.com> Date: Tue, 22 Apr 2025 03:40:04 +0530 Subject: [PATCH 05/74] File formating --- backend/app/__init__.py | 2 +- backend/app/auth/dependencies.py | 2 +- backend/app/auth/routers.py | 19 ++--- backend/app/contextual_mab/models.py | 10 +-- backend/app/contextual_mab/routers.py | 68 ++++++----------- backend/app/email.py | 13 +++- backend/app/mab/models.py | 10 +-- backend/app/mab/routers.py | 72 +++++++----------- backend/app/models.py | 6 +- backend/app/users/models.py | 3 +- backend/app/users/routers.py | 3 +- backend/app/workspaces/models.py | 15 ++-- backend/app/workspaces/routers.py | 66 ++++++++-------- backend/app/workspaces/schemas.py | 11 ++- backend/app/workspaces/utils.py | 6 +- .../949c9fc0461d_workspace_relationship.py | 20 +++-- .../versions/977e7e73ce06_workspace_model.py | 75 +++++++++++-------- .../app/(protected)/workspace/create/page.tsx | 10 
+-- .../src/app/(protected)/workspace/page.tsx | 2 +- .../src/app/(protected)/workspace/types.ts | 1 - frontend/src/components/WorkspaceSelector.tsx | 6 +- frontend/src/utils/api.ts | 2 +- frontend/src/utils/auth.tsx | 12 +-- 23 files changed, 204 insertions(+), 230 deletions(-) diff --git a/backend/app/__init__.py b/backend/app/__init__.py index 2653972..0f97151 100644 --- a/backend/app/__init__.py +++ b/backend/app/__init__.py @@ -10,8 +10,8 @@ from .users.routers import ( router as users_router, ) # to avoid circular imports -from .workspaces.routers import router as workspaces_router from .utils import setup_logger +from .workspaces.routers import router as workspaces_router logger = setup_logger() diff --git a/backend/app/auth/dependencies.py b/backend/app/auth/dependencies.py index d34a593..3fcaefd 100644 --- a/backend/app/auth/dependencies.py +++ b/backend/app/auth/dependencies.py @@ -198,7 +198,7 @@ def create_access_token(username: str, workspace_name: str = None) -> str: payload["iat"] = datetime.now(timezone.utc) payload["sub"] = username payload["type"] = "access_token" - + if workspace_name: payload["workspace_name"] = workspace_name diff --git a/backend/app/auth/routers.py b/backend/app/auth/routers.py index c989ae8..04346d9 100644 --- a/backend/app/auth/routers.py +++ b/backend/app/auth/routers.py @@ -7,7 +7,6 @@ from sqlalchemy.ext.asyncio import AsyncSession from ..config import DEFAULT_API_QUOTA, DEFAULT_EXPERIMENTS_QUOTA - from ..database import get_async_session, get_redis from ..email import EmailService from ..users.models import ( @@ -101,12 +100,12 @@ async def login_google( # Import here to avoid circular imports from ..workspaces.models import ( - create_user_workspace_role, + UserRoles, + create_user_workspace_role, get_user_default_workspace, - UserRoles ) from ..workspaces.utils import create_workspace - + user_email = idinfo["email"] user = await authenticate_or_create_google_user( request=request, google_email=user_email, 
asession=asession @@ -118,15 +117,17 @@ async def login_google( ) user_db = await get_user_by_username(username=user_email, asession=asession) - + # Create default workspace if user is new (has no workspaces) try: - default_workspace = await get_user_default_workspace(asession=asession, user_db=user_db) + default_workspace = await get_user_default_workspace( + asession=asession, user_db=user_db + ) default_workspace_name = default_workspace.workspace_name except Exception: # User doesn't have a default workspace, create one default_workspace_name = f"{user_email}'s Workspace" - + # Create default workspace workspace_db, _ = await create_workspace( api_daily_quota=DEFAULT_API_QUOTA, @@ -137,9 +138,9 @@ async def login_google( username=user_email, workspace_name=default_workspace_name, ), - is_default=True + is_default=True, ) - + await create_user_workspace_role( asession=asession, is_default_workspace=True, diff --git a/backend/app/contextual_mab/models.py b/backend/app/contextual_mab/models.py index da4d699..f88eaa4 100644 --- a/backend/app/contextual_mab/models.py +++ b/backend/app/contextual_mab/models.py @@ -265,10 +265,7 @@ async def get_all_contextual_mabs( async def get_contextual_mab_by_id( - experiment_id: int, - user_id: int, - workspace_id: int, - asession: AsyncSession + experiment_id: int, user_id: int, workspace_id: int, asession: AsyncSession ) -> ContextualBanditDB | None: """ Get the contextual experiment by id. @@ -284,10 +281,7 @@ async def get_contextual_mab_by_id( async def delete_contextual_mab_by_id( - experiment_id: int, - user_id: int, - workspace_id: int, - asession: AsyncSession + experiment_id: int, user_id: int, workspace_id: int, asession: AsyncSession ) -> None: """ Delete the contextual experiment by id. 
diff --git a/backend/app/contextual_mab/routers.py b/backend/app/contextual_mab/routers.py index c56aa32..e9484bf 100644 --- a/backend/app/contextual_mab/routers.py +++ b/backend/app/contextual_mab/routers.py @@ -4,14 +4,13 @@ from fastapi.exceptions import HTTPException from sqlalchemy.ext.asyncio import AsyncSession -from ..workspaces.models import get_user_default_workspace, get_user_role_in_workspace -from ..workspaces.schemas import UserRoles - from ..auth.dependencies import authenticate_key, get_verified_user from ..database import get_async_session from ..models import get_notifications_from_db, save_notifications_to_db from ..schemas import ContextType, NotificationsResponse, Outcome from ..users.models import UserDB +from ..workspaces.models import get_user_default_workspace, get_user_role_in_workspace +from ..workspaces.schemas import UserRoles from .models import ( delete_contextual_mab_by_id, get_all_contextual_mabs, @@ -45,22 +44,19 @@ async def create_contextual_mabs( Create a new contextual experiment with different priors for each context. """ workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) - + user_role = await get_user_role_in_workspace( asession=asession, user_db=user_db, workspace_db=workspace_db ) - + if user_role != UserRoles.ADMIN: raise HTTPException( status_code=403, detail="Only workspace administrators can create experiments.", ) - + cmab = await save_contextual_mab_to_db( - experiment, - user_db.user_id, - workspace_db.workspace_id, - asession + experiment, user_db.user_id, workspace_db.workspace_id, asession ) notifications = await save_notifications_to_db( experiment_id=cmab.experiment_id, @@ -82,11 +78,9 @@ async def get_contextual_mabs( Get details of all experiments. 
""" workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) - + experiments = await get_all_contextual_mabs( - user_db.user_id, - workspace_db.workspace_id, - asession + user_db.user_id, workspace_db.workspace_id, asession ) all_experiments = [] for exp in experiments: @@ -124,10 +118,7 @@ async def get_contextual_mab( workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) experiment = await get_contextual_mab_by_id( - experiment_id, - user_db.user_id, - workspace_db.workspace_id, - asession + experiment_id, user_db.user_id, workspace_db.workspace_id, asession ) if experiment is None: raise HTTPException( @@ -154,33 +145,29 @@ async def delete_contextual_mab( Delete the experiment with the provided `experiment_id`. """ try: - workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) - + workspace_db = await get_user_default_workspace( + asession=asession, user_db=user_db + ) + user_role = await get_user_role_in_workspace( asession=asession, user_db=user_db, workspace_db=workspace_db ) - + if user_role != UserRoles.ADMIN: raise HTTPException( status_code=403, detail="Only workspace administrators can delete experiments.", ) - + experiment = await get_contextual_mab_by_id( - experiment_id, - user_db.user_id, - workspace_db.workspace_id, - asession + experiment_id, user_db.user_id, workspace_db.workspace_id, asession ) if experiment is None: raise HTTPException( status_code=404, detail=f"Experiment with id {experiment_id} not found" ) await delete_contextual_mab_by_id( - experiment_id, - user_db.user_id, - workspace_db.workspace_id, - asession + experiment_id, user_db.user_id, workspace_db.workspace_id, asession ) return {"detail": f"Experiment {experiment_id} deleted successfully."} except Exception as e: @@ -198,12 +185,9 @@ async def draw_arm( Get which arm to pull next for provided experiment. 
""" workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) - + experiment = await get_contextual_mab_by_id( - experiment_id, - user_db.user_id, - workspace_db.workspace_id, - asession + experiment_id, user_db.user_id, workspace_db.workspace_id, asession ) if experiment is None: @@ -247,13 +231,10 @@ async def update_arm( `experiment_id` based on the `outcome`. """ workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) - + # Get the experiment and do checks experiment = await get_contextual_mab_by_id( - experiment_id, - user_db.user_id, - workspace_db.workspace_id, - asession + experiment_id, user_db.user_id, workspace_db.workspace_id, asession ) if experiment is None: raise HTTPException( @@ -346,12 +327,9 @@ async def get_outcomes( Get the outcomes for the experiment. """ workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) - + experiment = await get_contextual_mab_by_id( - experiment_id, - user_db.user_id, - workspace_db.workspace_id, - asession + experiment_id, user_db.user_id, workspace_db.workspace_id, asession ) if not experiment: raise HTTPException( diff --git a/backend/app/email.py b/backend/app/email.py index 9770090..d75d90a 100644 --- a/backend/app/email.py +++ b/backend/app/email.py @@ -125,7 +125,12 @@ async def send_password_reset_email( return await self._send_email(email, subject, html_body, text_body) async def send_workspace_invitation_email( - self, email: str, username: str, inviter_email: str, workspace_name: str, user_exists: bool + self, + email: str, + username: str, + inviter_email: str, + workspace_name: str, + user_exists: bool, ) -> Dict[str, Any]: """ Send workspace invitation email @@ -150,7 +155,7 @@ async def send_workspace_invitation_email( Hello {username}, You have been invited by {inviter_email} to join the workspace "{workspace_name}". - + You have been added to this workspace. Log in to access it. 
{FRONTEND_URL}/login @@ -175,12 +180,12 @@ async def send_workspace_invitation_email( Hello, You have been invited by {inviter_email} to join the workspace "{workspace_name}". - + You need to create an account to join this workspace. {FRONTEND_URL}/signup """ - + return await self._send_email(email, subject, html_body, text_body) async def _send_email( diff --git a/backend/app/mab/models.py b/backend/app/mab/models.py index 377498d..416193a 100644 --- a/backend/app/mab/models.py +++ b/backend/app/mab/models.py @@ -198,10 +198,7 @@ async def get_all_mabs( async def get_mab_by_id( - experiment_id: int, - user_id: int, - workspace_id: int, - asession: AsyncSession + experiment_id: int, user_id: int, workspace_id: int, asession: AsyncSession ) -> MultiArmedBanditDB | None: """ Get the experiment by id. @@ -217,10 +214,7 @@ async def get_mab_by_id( async def delete_mab_by_id( - experiment_id: int, - user_id: int, - workspace_id: int, - asession: AsyncSession + experiment_id: int, user_id: int, workspace_id: int, asession: AsyncSession ) -> None: """ Delete the experiment by id. 
diff --git a/backend/app/mab/routers.py b/backend/app/mab/routers.py index fd5af38..f700d50 100644 --- a/backend/app/mab/routers.py +++ b/backend/app/mab/routers.py @@ -5,14 +5,13 @@ from fastapi.exceptions import HTTPException from sqlalchemy.ext.asyncio import AsyncSession -from ..workspaces.models import get_user_default_workspace, get_user_role_in_workspace -from ..workspaces.schemas import UserRoles - from ..auth.dependencies import authenticate_key, get_verified_user from ..database import get_async_session from ..models import get_notifications_from_db, save_notifications_to_db from ..schemas import NotificationsResponse, Outcome, RewardLikelihood from ..users.models import UserDB +from ..workspaces.models import get_user_default_workspace, get_user_role_in_workspace +from ..workspaces.schemas import UserRoles from .models import ( delete_mab_by_id, get_all_mabs, @@ -44,24 +43,21 @@ async def create_mab( Create a new experiment in the user's current workspace. """ workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) - + user_role = await get_user_role_in_workspace( asession=asession, user_db=user_db, workspace_db=workspace_db ) - + if user_role != UserRoles.ADMIN: raise HTTPException( status_code=403, detail="Only workspace administrators can create experiments.", ) - + mab = await save_mab_to_db( - experiment, - user_db.user_id, - workspace_db.workspace_id, - asession + experiment, user_db.user_id, workspace_db.workspace_id, asession ) - + notifications = await save_notifications_to_db( experiment_id=mab.experiment_id, user_id=user_db.user_id, @@ -84,11 +80,9 @@ async def get_mabs( Get details of all experiments in the user's current workspace. 
""" workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) - + experiments = await get_all_mabs( - user_db.user_id, - workspace_db.workspace_id, - asession + user_db.user_id, workspace_db.workspace_id, asession ) all_experiments = [] @@ -123,12 +117,9 @@ async def get_mab( Get details of experiment with the provided `experiment_id`. """ workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) - + experiment = await get_mab_by_id( - experiment_id, - user_db.user_id, - workspace_db.workspace_id, - asession + experiment_id, user_db.user_id, workspace_db.workspace_id, asession ) if experiment is None: @@ -157,33 +148,29 @@ async def delete_mab( Delete the experiment with the provided `experiment_id`. """ try: - workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) - + workspace_db = await get_user_default_workspace( + asession=asession, user_db=user_db + ) + user_role = await get_user_role_in_workspace( asession=asession, user_db=user_db, workspace_db=workspace_db ) - + if user_role != UserRoles.ADMIN: raise HTTPException( status_code=403, detail="Only workspace administrators can delete experiments.", ) - + experiment = await get_mab_by_id( - experiment_id, - user_db.user_id, - workspace_db.workspace_id, - asession + experiment_id, user_db.user_id, workspace_db.workspace_id, asession ) if experiment is None: raise HTTPException( status_code=404, detail=f"Experiment with id {experiment_id} not found" ) await delete_mab_by_id( - experiment_id, - user_db.user_id, - workspace_db.workspace_id, - asession + experiment_id, user_db.user_id, workspace_db.workspace_id, asession ) return {"message": f"Experiment with id {experiment_id} deleted successfully."} except Exception as e: @@ -200,12 +187,9 @@ async def draw_arm( Get which arm to pull next for provided experiment. 
""" workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) - + experiment = await get_mab_by_id( - experiment_id, - user_db.user_id, - workspace_db.workspace_id, - asession + experiment_id, user_db.user_id, workspace_db.workspace_id, asession ) if experiment is None: raise HTTPException( @@ -229,12 +213,9 @@ async def update_arm( `experiment_id` based on the `outcome`. """ workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) - + experiment = await get_mab_by_id( - experiment_id, - user_db.user_id, - workspace_db.workspace_id, - asession + experiment_id, user_db.user_id, workspace_db.workspace_id, asession ) if experiment is None: raise HTTPException( @@ -303,12 +284,9 @@ async def get_outcomes( Get the outcomes for the experiment. """ workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) - + experiment = await get_mab_by_id( - experiment_id, - user_db.user_id, - workspace_db.workspace_id, - asession + experiment_id, user_db.user_id, workspace_db.workspace_id, asession ) if not experiment: raise HTTPException( diff --git a/backend/app/models.py b/backend/app/models.py index 5314ebb..162461b 100644 --- a/backend/app/models.py +++ b/backend/app/models.py @@ -44,8 +44,10 @@ class ExperimentBaseDB(Base): last_trial_datetime_utc: Mapped[datetime] = mapped_column( DateTime(timezone=True), nullable=True ) - workspace: Mapped["WorkspaceDB"] = relationship("WorkspaceDB", back_populates="experiments") - + workspace: Mapped["WorkspaceDB"] = relationship( + "WorkspaceDB", back_populates="experiments" + ) + __mapper_args__ = { "polymorphic_identity": "experiment", "polymorphic_on": "exp_type", diff --git a/backend/app/users/models.py b/backend/app/users/models.py index 6f8bfc2..9d68acd 100644 --- a/backend/app/users/models.py +++ b/backend/app/users/models.py @@ -11,10 +11,9 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Mapped, mapped_column, relationship 
-from ..workspaces.models import UserWorkspaceDB, WorkspaceDB - from ..models import Base from ..utils import get_key_hash, get_password_salted_hash, get_random_string +from ..workspaces.models import UserWorkspaceDB, WorkspaceDB from .schemas import UserCreate, UserCreateWithPassword PASSWORD_LENGTH = 12 diff --git a/backend/app/users/routers.py b/backend/app/users/routers.py index 90844bb..dee1776 100644 --- a/backend/app/users/routers.py +++ b/backend/app/users/routers.py @@ -6,10 +6,9 @@ from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession -from ..config import DEFAULT_API_QUOTA, DEFAULT_EXPERIMENTS_QUOTA - from ..auth.dependencies import get_current_user, get_verified_user from ..auth.utils import generate_verification_token +from ..config import DEFAULT_API_QUOTA, DEFAULT_EXPERIMENTS_QUOTA from ..database import get_async_session, get_redis from ..email import EmailService from ..users.models import ( diff --git a/backend/app/workspaces/models.py b/backend/app/workspaces/models.py index f7b37a2..9547917 100644 --- a/backend/app/workspaces/models.py +++ b/backend/app/workspaces/models.py @@ -1,14 +1,12 @@ from datetime import datetime, timezone -from typing import Sequence, TYPE_CHECKING +from typing import TYPE_CHECKING, Sequence -import sqlalchemy.sql.functions as func from sqlalchemy import ( Boolean, DateTime, Enum, ForeignKey, Integer, - Row, String, case, exists, @@ -16,15 +14,12 @@ text, update, ) -from sqlalchemy.exc import NoResultFound from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Mapped, mapped_column, relationship -from ..users.schemas import UserCreate - from ..models import Base, ExperimentBaseDB -from ..utils import get_key_hash -from .schemas import UserRoles, UserCreateWithCode +from ..users.schemas import UserCreate +from .schemas import UserCreateWithCode, UserRoles # Use TYPE_CHECKING for type hints without runtime imports if TYPE_CHECKING: @@ -193,7 +188,7 @@ async def 
update_user_default_workspace( (UserWorkspaceDB.workspace_id == workspace_db.workspace_id, True), else_=False, ), - updated_datetime_utc=datetime.now(timezone.utc) + updated_datetime_utc=datetime.now(timezone.utc), ) ) @@ -275,7 +270,7 @@ async def add_existing_user_to_workspace( """Add an existing user to a workspace.""" # Import here to avoid circular imports from ..users.models import get_user_by_username - + assert user.role is not None user.is_default_workspace = user.is_default_workspace or False diff --git a/backend/app/workspaces/routers.py b/backend/app/workspaces/routers.py index d36fb1a..2cb8800 100644 --- a/backend/app/workspaces/routers.py +++ b/backend/app/workspaces/routers.py @@ -2,9 +2,9 @@ from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, status from fastapi.exceptions import HTTPException +from redis.asyncio import Redis from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession -from redis.asyncio import Redis from ..auth.dependencies import ( create_access_token, @@ -17,7 +17,6 @@ from ..users.models import ( UserDB, UserNotFoundError, - save_user_to_db, get_user_by_username, ) from ..users.schemas import UserCreate @@ -28,9 +27,7 @@ get_user_default_workspace, get_user_role_in_workspace, get_user_workspaces, - get_workspaces_by_user_role, update_user_default_workspace, - user_has_admin_role_in_any_workspace, ) from .schemas import ( UserRoles, @@ -113,7 +110,7 @@ async def create_workspace_endpoint( ), workspace_db=workspace_db, ) - + return WorkspaceRetrieve( api_daily_quota=workspace_db.api_daily_quota, api_key_first_characters=workspace_db.api_key_first_characters, @@ -138,8 +135,10 @@ async def retrieve_all_workspaces( asession: AsyncSession = Depends(get_async_session), ) -> List[WorkspaceRetrieve]: """Return a list of all workspaces the user has access to.""" - user_workspaces = await get_user_workspaces(asession=asession, user_db=calling_user_db) - + user_workspaces = await 
get_user_workspaces( + asession=asession, user_db=calling_user_db + ) + return [ WorkspaceRetrieve( api_daily_quota=workspace_db.api_daily_quota, @@ -166,7 +165,7 @@ async def get_current_workspace( workspace_db = await get_user_default_workspace( asession=asession, user_db=calling_user_db ) - + return WorkspaceRetrieve( api_daily_quota=workspace_db.api_daily_quota, api_key_first_characters=workspace_db.api_key_first_characters, @@ -207,7 +206,7 @@ async def switch_workspace( user_role = await get_user_role_in_workspace( asession=asession, user_db=calling_user_db, workspace_db=workspace_db ) - + if user_role is None: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -241,25 +240,25 @@ async def rotate_workspace_api_key( workspace_db = await get_user_default_workspace( asession=asession, user_db=calling_user_db ) - + # Verify user is an admin in this workspace user_role = await get_user_role_in_workspace( asession=asession, user_db=calling_user_db, workspace_db=workspace_db ) - + if user_role != UserRoles.ADMIN: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Only workspace administrators can rotate API keys.", ) - + # Generate and update the API key new_api_key = generate_key() asession.add(workspace_db) workspace_db = await update_workspace_api_key( asession=asession, new_api_key=new_api_key, workspace_db=workspace_db ) - + return WorkspaceKeyResponse( new_api_key=new_api_key, workspace_name=workspace_db.workspace_name, @@ -284,18 +283,18 @@ async def retrieve_workspace_by_workspace_id( workspace_db = await get_workspace_by_workspace_id( asession=asession, workspace_id=workspace_id ) - + # Check if user has access to this workspace user_role = await get_user_role_in_workspace( asession=asession, user_db=calling_user_db, workspace_db=workspace_db ) - + if user_role is None: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=f"User does not have access to workspace with ID {workspace_id}.", ) - + return 
WorkspaceRetrieve( api_daily_quota=workspace_db.api_daily_quota, api_key_first_characters=workspace_db.api_key_first_characters, @@ -327,20 +326,23 @@ async def update_workspace_endpoint( workspace_db = await get_workspace_by_workspace_id( asession=asession, workspace_id=workspace_id ) - + # Verify user is an admin in this workspace user_role = await get_user_role_in_workspace( asession=asession, user_db=calling_user_db, workspace_db=workspace_db ) - + if user_role != UserRoles.ADMIN: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Only workspace administrators can update workspace details.", ) - + # Check if the new workspace name is valid - if workspace_update.workspace_name and workspace_update.workspace_name != workspace_db.workspace_name: + if ( + workspace_update.workspace_name + and workspace_update.workspace_name != workspace_db.workspace_name + ): if not await is_workspace_name_valid( asession=asession, workspace_name=workspace_update.workspace_name ): @@ -348,13 +350,13 @@ async def update_workspace_endpoint( status_code=status.HTTP_400_BAD_REQUEST, detail=f"Workspace with name '{workspace_update.workspace_name}' already exists.", ) - + # Update the workspace asession.add(workspace_db) updated_workspace = await update_workspace_name_and_quotas( asession=asession, workspace=workspace_update, workspace_db=workspace_db ) - + return WorkspaceRetrieve( api_daily_quota=updated_workspace.api_daily_quota, api_key_first_characters=updated_workspace.api_key_first_characters, @@ -393,25 +395,25 @@ async def invite_user_to_workspace( workspace_db = await get_workspace_by_workspace_name( asession=asession, workspace_name=invite.workspace_name ) - + # Check if it's a default workspace (users can't invite others to default workspaces) if workspace_db.is_default: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Users cannot be invited to default workspaces.", ) - + # Verify caller is an admin in this workspace user_role = await 
get_user_role_in_workspace( asession=asession, user_db=calling_user_db, workspace_db=workspace_db ) - + if user_role != UserRoles.ADMIN: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Only workspace administrators can invite users.", ) - + # Check if the invited user exists user_exists = False try: @@ -419,7 +421,7 @@ async def invite_user_to_workspace( username=invite.email, asession=asession ) user_exists = True - + # Add existing user to workspace await add_existing_user_to_workspace( asession=asession, @@ -430,7 +432,7 @@ async def invite_user_to_workspace( ), workspace_db=workspace_db, ) - + # Send invitation email to existing user background_tasks.add_task( email_service.send_workspace_invitation_email, @@ -440,14 +442,14 @@ async def invite_user_to_workspace( workspace_db.workspace_name, True, # user exists ) - + return WorkspaceInviteResponse( message=f"User {invite.email} has been added to workspace '{workspace_db.workspace_name}'.", email=invite.email, workspace_name=workspace_db.workspace_name, user_exists=True, ) - + except UserNotFoundError: # User doesn't exist, send invitation to create account background_tasks.add_task( @@ -458,14 +460,14 @@ async def invite_user_to_workspace( workspace_db.workspace_name, False, # user doesn't exist ) - + return WorkspaceInviteResponse( message=f"Invitation sent to {invite.email} to join workspace '{workspace_db.workspace_name}'.", email=invite.email, workspace_name=workspace_db.workspace_name, user_exists=False, ) - + except WorkspaceNotFoundError as e: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, diff --git a/backend/app/workspaces/schemas.py b/backend/app/workspaces/schemas.py index 5df8378..512a78b 100644 --- a/backend/app/workspaces/schemas.py +++ b/backend/app/workspaces/schemas.py @@ -1,6 +1,6 @@ from datetime import datetime from enum import Enum -from typing import Optional, List +from typing import List, Optional from pydantic import BaseModel, ConfigDict, EmailStr @@ 
-27,6 +27,7 @@ class UserRoles(str, Enum): class WorkspaceCreate(BaseModel): """Pydantic model for workspace creation.""" + api_daily_quota: int | None = None content_quota: int | None = None workspace_name: str @@ -36,6 +37,7 @@ class WorkspaceCreate(BaseModel): class WorkspaceKeyResponse(BaseModel): """Pydantic model for updating workspace API key.""" + new_api_key: str workspace_name: str @@ -44,6 +46,7 @@ class WorkspaceKeyResponse(BaseModel): class WorkspaceRetrieve(BaseModel): """Pydantic model for workspace retrieval.""" + api_daily_quota: Optional[int] = None api_key_first_characters: Optional[str] = None api_key_updated_datetime_utc: Optional[datetime] = None @@ -59,6 +62,7 @@ class WorkspaceRetrieve(BaseModel): class WorkspaceSwitch(BaseModel): """Pydantic model for switching workspaces.""" + workspace_name: str model_config = ConfigDict(from_attributes=True) @@ -66,6 +70,7 @@ class WorkspaceSwitch(BaseModel): class WorkspaceUpdate(BaseModel): """Pydantic model for workspace updates.""" + workspace_name: Optional[str] = None model_config = ConfigDict(from_attributes=True) @@ -73,6 +78,7 @@ class WorkspaceUpdate(BaseModel): class UserWorkspace(BaseModel): """Pydantic model for user workspace information.""" + user_role: UserRoles workspace_id: int workspace_name: str @@ -83,6 +89,7 @@ class UserWorkspace(BaseModel): class UserCreateWithCode(BaseModel): """Pydantic model for user creation with recovery codes.""" + is_default_workspace: bool = False recovery_codes: List[str] = [] role: UserRoles @@ -94,6 +101,7 @@ class UserCreateWithCode(BaseModel): class WorkspaceInvite(BaseModel): """Pydantic model for inviting users to a workspace.""" + email: EmailStr role: UserRoles = UserRoles.READ_ONLY workspace_name: str @@ -103,6 +111,7 @@ class WorkspaceInvite(BaseModel): class WorkspaceInviteResponse(BaseModel): """Pydantic model for workspace invite response.""" + message: str email: EmailStr workspace_name: str diff --git a/backend/app/workspaces/utils.py 
b/backend/app/workspaces/utils.py index 9ae7d83..4897303 100644 --- a/backend/app/workspaces/utils.py +++ b/backend/app/workspaces/utils.py @@ -42,7 +42,7 @@ async def create_workspace( ) workspace_db = result.scalar_one_or_none() new_workspace = False - + if workspace_db is None: new_workspace = True workspace_db = WorkspaceDB( @@ -51,9 +51,9 @@ async def create_workspace( created_datetime_utc=datetime.now(timezone.utc), updated_datetime_utc=datetime.now(timezone.utc), workspace_name=user.workspace_name, - is_default=is_default + is_default=is_default, ) - + if api_key: workspace_db.hashed_api_key = get_key_hash(api_key) workspace_db.api_key_first_characters = api_key[:5] diff --git a/backend/migrations/versions/949c9fc0461d_workspace_relationship.py b/backend/migrations/versions/949c9fc0461d_workspace_relationship.py index 020cb27..d04edf2 100644 --- a/backend/migrations/versions/949c9fc0461d_workspace_relationship.py +++ b/backend/migrations/versions/949c9fc0461d_workspace_relationship.py @@ -5,28 +5,32 @@ Create Date: 2025-04-21 21:20:56.282928 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa - +from alembic import op # revision identifiers, used by Alembic. -revision: str = '949c9fc0461d' -down_revision: Union[str, None] = '977e7e73ce06' +revision: str = "949c9fc0461d" +down_revision: Union[str, None] = "977e7e73ce06" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! 
### - op.add_column('experiments_base', sa.Column('workspace_id', sa.Integer(), nullable=False)) - op.create_foreign_key(None, 'experiments_base', 'workspace', ['workspace_id'], ['workspace_id']) + op.add_column( + "experiments_base", sa.Column("workspace_id", sa.Integer(), nullable=False) + ) + op.create_foreign_key( + None, "experiments_base", "workspace", ["workspace_id"], ["workspace_id"] + ) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_constraint(None, 'experiments_base', type_='foreignkey') - op.drop_column('experiments_base', 'workspace_id') + op.drop_constraint(None, "experiments_base", type_="foreignkey") + op.drop_column("experiments_base", "workspace_id") # ### end Alembic commands ### diff --git a/backend/migrations/versions/977e7e73ce06_workspace_model.py b/backend/migrations/versions/977e7e73ce06_workspace_model.py index 81ff8f5..bb4ff38 100644 --- a/backend/migrations/versions/977e7e73ce06_workspace_model.py +++ b/backend/migrations/versions/977e7e73ce06_workspace_model.py @@ -5,52 +5,67 @@ Create Date: 2025-04-20 20:17:32.839934 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa - +from alembic import op # revision identifiers, used by Alembic. -revision: str = '977e7e73ce06' -down_revision: Union[str, None] = 'ba1bf29910f5' +revision: str = "977e7e73ce06" +down_revision: Union[str, None] = "ba1bf29910f5" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! 
### - op.create_table('workspace', - sa.Column('api_daily_quota', sa.Integer(), nullable=True), - sa.Column('api_key_first_characters', sa.String(length=5), nullable=True), - sa.Column('api_key_updated_datetime_utc', sa.DateTime(timezone=True), nullable=True), - sa.Column('content_quota', sa.Integer(), nullable=True), - sa.Column('created_datetime_utc', sa.DateTime(timezone=True), nullable=False), - sa.Column('hashed_api_key', sa.String(length=96), nullable=True), - sa.Column('updated_datetime_utc', sa.DateTime(timezone=True), nullable=False), - sa.Column('workspace_id', sa.Integer(), nullable=False), - sa.Column('workspace_name', sa.String(), nullable=False), - sa.Column('is_default', sa.Boolean(), nullable=False), - sa.PrimaryKeyConstraint('workspace_id'), - sa.UniqueConstraint('hashed_api_key'), - sa.UniqueConstraint('workspace_name') + op.create_table( + "workspace", + sa.Column("api_daily_quota", sa.Integer(), nullable=True), + sa.Column("api_key_first_characters", sa.String(length=5), nullable=True), + sa.Column( + "api_key_updated_datetime_utc", sa.DateTime(timezone=True), nullable=True + ), + sa.Column("content_quota", sa.Integer(), nullable=True), + sa.Column("created_datetime_utc", sa.DateTime(timezone=True), nullable=False), + sa.Column("hashed_api_key", sa.String(length=96), nullable=True), + sa.Column("updated_datetime_utc", sa.DateTime(timezone=True), nullable=False), + sa.Column("workspace_id", sa.Integer(), nullable=False), + sa.Column("workspace_name", sa.String(), nullable=False), + sa.Column("is_default", sa.Boolean(), nullable=False), + sa.PrimaryKeyConstraint("workspace_id"), + sa.UniqueConstraint("hashed_api_key"), + sa.UniqueConstraint("workspace_name"), ) - op.create_table('user_workspace', - sa.Column('created_datetime_utc', sa.DateTime(timezone=True), nullable=False), - sa.Column('default_workspace', sa.Boolean(), server_default=sa.text('false'), nullable=False), - sa.Column('updated_datetime_utc', sa.DateTime(timezone=True), 
nullable=False), - sa.Column('user_id', sa.Integer(), nullable=False), - sa.Column('user_role', sa.Enum('ADMIN', 'READ_ONLY', name='userroles', native_enum=False), nullable=False), - sa.Column('workspace_id', sa.Integer(), nullable=False), - sa.ForeignKeyConstraint(['user_id'], ['users.user_id'], ondelete='CASCADE'), - sa.ForeignKeyConstraint(['workspace_id'], ['workspace.workspace_id'], ondelete='CASCADE'), - sa.PrimaryKeyConstraint('user_id', 'workspace_id') + op.create_table( + "user_workspace", + sa.Column("created_datetime_utc", sa.DateTime(timezone=True), nullable=False), + sa.Column( + "default_workspace", + sa.Boolean(), + server_default=sa.text("false"), + nullable=False, + ), + sa.Column("updated_datetime_utc", sa.DateTime(timezone=True), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column( + "user_role", + sa.Enum("ADMIN", "READ_ONLY", name="userroles", native_enum=False), + nullable=False, + ), + sa.Column("workspace_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(["user_id"], ["users.user_id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint( + ["workspace_id"], ["workspace.workspace_id"], ondelete="CASCADE" + ), + sa.PrimaryKeyConstraint("user_id", "workspace_id"), ) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! 
### - op.drop_table('user_workspace') - op.drop_table('workspace') + op.drop_table("user_workspace") + op.drop_table("workspace") # ### end Alembic commands ### diff --git a/frontend/src/app/(protected)/workspace/create/page.tsx b/frontend/src/app/(protected)/workspace/create/page.tsx index 2d57b60..99aff9e 100644 --- a/frontend/src/app/(protected)/workspace/create/page.tsx +++ b/frontend/src/app/(protected)/workspace/create/page.tsx @@ -54,14 +54,14 @@ export default function CreateWorkspacePage() { setIsSubmitting(true); try { const response = await apiCalls.createWorkspace(token, data); - + await switchWorkspace(response.workspace_name); - + toast({ title: "Success", description: `Workspace "${response.workspace_name}" created and activated!`, }); - + router.push("/workspace"); } catch (error: any) { toast({ @@ -73,7 +73,7 @@ export default function CreateWorkspacePage() { setIsSubmitting(false); } }; - + return (
@@ -106,7 +106,7 @@ export default function CreateWorkspacePage() { Choose a descriptive name for your new workspace. - +
); -} \ No newline at end of file +} diff --git a/frontend/src/app/(protected)/workspace/types.ts b/frontend/src/app/(protected)/workspace/types.ts index 65c4108..2d53c31 100644 --- a/frontend/src/app/(protected)/workspace/types.ts +++ b/frontend/src/app/(protected)/workspace/types.ts @@ -1,6 +1,5 @@ export enum UserRoles { ADMIN = "ADMIN", - EDITOR = "EDITOR", VIEWER = "VIEWER", } diff --git a/frontend/src/components/WorkspaceSelector.tsx b/frontend/src/components/WorkspaceSelector.tsx index c958ea1..4692888 100644 --- a/frontend/src/components/WorkspaceSelector.tsx +++ b/frontend/src/components/WorkspaceSelector.tsx @@ -89,9 +89,9 @@ export default function WorkspaceSelector() { {workspace.workspace_name} ))} - + - + Create New Workspace @@ -139,4 +139,4 @@ export default function WorkspaceSelector() {
); -} \ No newline at end of file +} diff --git a/frontend/src/utils/api.ts b/frontend/src/utils/api.ts index 23ca704..41edbdd 100644 --- a/frontend/src/utils/api.ts +++ b/frontend/src/utils/api.ts @@ -144,7 +144,7 @@ const updateWorkspace = async (token: string | null, workspaceId: number, worksp const switchWorkspace = async (token: string | null, workspaceName: string) => { try { - const response = await api.post("/workspace/switch", + const response = await api.post("/workspace/switch", { workspace_name: workspaceName }, { headers: { diff --git a/frontend/src/utils/auth.tsx b/frontend/src/utils/auth.tsx index 814cf98..ee6650d 100644 --- a/frontend/src/utils/auth.tsx +++ b/frontend/src/utils/auth.tsx @@ -70,7 +70,7 @@ const AuthProvider = ({ children }: AuthProviderProps) => { // Fetch current workspace const currentWorkspaceData = await apiCalls.getCurrentWorkspace(token); setCurrentWorkspace(currentWorkspaceData); - + // Fetch all workspaces const workspacesData = await apiCalls.getAllWorkspaces(token); setWorkspaces(workspacesData); @@ -139,7 +139,7 @@ const AuthProvider = ({ children }: AuthProviderProps) => { try { const currentWorkspaceData = await apiCalls.getCurrentWorkspace(access_token); setCurrentWorkspace(currentWorkspaceData); - + const workspacesData = await apiCalls.getAllWorkspaces(access_token); setWorkspaces(workspacesData); } catch (error) { @@ -172,13 +172,13 @@ const AuthProvider = ({ children }: AuthProviderProps) => { try { setIsLoading(true); const response = await apiCalls.switchWorkspace(token, workspaceName); - + localStorage.setItem("ee-token", response.access_token); setToken(response.access_token); - + const currentWorkspaceData = await apiCalls.getCurrentWorkspace(response.access_token); setCurrentWorkspace(currentWorkspaceData); - + return response; } catch (error) { console.error("Error switching workspace:", error); @@ -219,7 +219,7 @@ const AuthProvider = ({ children }: AuthProviderProps) => { try { const currentWorkspaceData = 
await apiCalls.getCurrentWorkspace(access_token); setCurrentWorkspace(currentWorkspaceData); - + const workspacesData = await apiCalls.getAllWorkspaces(access_token); setWorkspaces(workspacesData); } catch (error) { From 43a2c3851473527358d2156111924432c889a056 Mon Sep 17 00:00:00 2001 From: Jay Prakash <0freerunning@gmail.com> Date: Thu, 1 May 2025 02:04:05 +0530 Subject: [PATCH 06/74] Added workspace removal and user list --- backend/app/auth/dependencies.py | 2 +- backend/app/auth/routers.py | 2 +- backend/app/bayes_ab/models.py | 19 ++- backend/app/bayes_ab/routers.py | 111 ++++++++++++-- backend/app/bayes_ab/schemas.py | 1 + backend/app/contextual_mab/routers.py | 2 +- backend/app/mab/routers.py | 2 +- backend/app/users/exceptions.py | 6 + backend/app/users/models.py | 9 +- backend/app/users/routers.py | 39 ++++- backend/app/workspaces/models.py | 119 +++++++++++++++ backend/app/workspaces/routers.py | 139 +++++++++++++++++- backend/app/workspaces/schemas.py | 14 ++ .../949c9fc0461d_workspace_relationship.py | 36 ----- .../versions/977e7e73ce06_workspace_model.py | 71 --------- .../versions/d9f7a309944e_workspace_model.py | 72 +++++++++ 16 files changed, 501 insertions(+), 143 deletions(-) create mode 100644 backend/app/users/exceptions.py delete mode 100644 backend/migrations/versions/949c9fc0461d_workspace_relationship.py delete mode 100644 backend/migrations/versions/977e7e73ce06_workspace_model.py create mode 100644 backend/migrations/versions/d9f7a309944e_workspace_model.py diff --git a/backend/app/auth/dependencies.py b/backend/app/auth/dependencies.py index 63e1d3a..a4d18f3 100644 --- a/backend/app/auth/dependencies.py +++ b/backend/app/auth/dependencies.py @@ -16,12 +16,12 @@ from ..database import get_async_session from ..users.models import ( UserDB, - UserNotFoundError, get_user_by_api_key, get_user_by_username, save_user_to_db, update_user_verification_status, ) +from ..users.exceptions import UserNotFoundError from ..users.schemas import 
UserCreate from ..utils import ( generate_key, diff --git a/backend/app/auth/routers.py b/backend/app/auth/routers.py index f9f1934..d447014 100644 --- a/backend/app/auth/routers.py +++ b/backend/app/auth/routers.py @@ -10,11 +10,11 @@ from ..database import get_async_session, get_redis from ..email import EmailService from ..users.models import ( - UserNotFoundError, get_user_by_username, update_user_password, update_user_verification_status, ) +from ..users.exceptions import UserNotFoundError from ..users.schemas import ( EmailVerificationRequest, MessageResponse, diff --git a/backend/app/bayes_ab/models.py b/backend/app/bayes_ab/models.py index adf1745..29c61f7 100644 --- a/backend/app/bayes_ab/models.py +++ b/backend/app/bayes_ab/models.py @@ -52,6 +52,7 @@ def to_dict(self) -> dict: return { "experiment_id": self.experiment_id, "user_id": self.user_id, + "workspace_id": self.workspace_id, "name": self.name, "description": self.description, "sticky_assignment": self.sticky_assignment, @@ -156,6 +157,7 @@ def to_dict(self) -> dict: async def save_bayes_ab_to_db( ab_experiment: BayesianAB, user_id: int, + workspace_id: int, asession: AsyncSession, ) -> BayesianABDB: """ @@ -180,6 +182,7 @@ async def save_bayes_ab_to_db( name=ab_experiment.name, description=ab_experiment.description, user_id=user_id, + workspace_id=workspace_id, is_active=ab_experiment.is_active, created_datetime_utc=datetime.now(timezone.utc), n_trials=0, @@ -201,14 +204,18 @@ async def save_bayes_ab_to_db( async def get_all_bayes_ab_experiments( user_id: int, + workspace_id: int, asession: AsyncSession, ) -> Sequence[BayesianABDB]: """ - Get all the A/B experiments from the database. + Get all the A/B experiments from the database for a specific workspace. 
""" stmt = ( select(BayesianABDB) - .where(BayesianABDB.user_id == user_id) + .where( + BayesianABDB.user_id == user_id, + BayesianABDB.workspace_id == workspace_id + ) .order_by(BayesianABDB.experiment_id) ) result = await asession.execute(stmt) @@ -218,14 +225,16 @@ async def get_all_bayes_ab_experiments( async def get_bayes_ab_experiment_by_id( experiment_id: int, user_id: int, + workspace_id: int, asession: AsyncSession, ) -> BayesianABDB | None: """ - Get the A/B experiment by id. + Get the A/B experiment by id from a specific workspace. """ stmt = select(BayesianABDB).where( and_( BayesianABDB.user_id == user_id, + BayesianABDB.workspace_id == workspace_id, BayesianABDB.experiment_id == experiment_id, ) ) @@ -236,14 +245,16 @@ async def get_bayes_ab_experiment_by_id( async def delete_bayes_ab_experiment_by_id( experiment_id: int, user_id: int, + workspace_id: int, asession: AsyncSession, ) -> None: """ - Delete the A/B experiment by id. + Delete the A/B experiment by id from a specific workspace. 
""" stmt = delete(BayesianABDB).where( and_( BayesianABDB.user_id == user_id, + BayesianABDB.workspace_id == workspace_id, BayesianABDB.experiment_id == experiment_id, BayesianABDB.experiment_id == ExperimentBaseDB.experiment_id, ) diff --git a/backend/app/bayes_ab/routers.py b/backend/app/bayes_ab/routers.py index 5b095e5..bde53e4 100644 --- a/backend/app/bayes_ab/routers.py +++ b/backend/app/bayes_ab/routers.py @@ -2,7 +2,7 @@ from typing import Annotated, Optional from uuid import uuid4 -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, status from fastapi.exceptions import HTTPException from sqlalchemy.ext.asyncio import AsyncSession @@ -11,6 +11,8 @@ from ..models import get_notifications_from_db, save_notifications_to_db from ..schemas import NotificationsResponse, Outcome, RewardLikelihood from ..users.models import UserDB +from ..workspaces.models import get_user_default_workspace, get_user_role_in_workspace +from ..workspaces.schemas import UserRoles from .models import ( BayesianABArmDB, BayesianABDB, @@ -44,9 +46,27 @@ async def create_ab_experiment( asession: AsyncSession = Depends(get_async_session), ) -> BayesianABResponse: """ - Create a new experiment. + Create a new experiment in the user's current workspace. 
""" - bayes_ab = await save_bayes_ab_to_db(experiment, user_db.user_id, asession) + workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) + + user_role = await get_user_role_in_workspace( + asession=asession, user_db=user_db, workspace_db=workspace_db + ) + + if user_role != UserRoles.ADMIN: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Only workspace administrators can create experiments.", + ) + + bayes_ab = await save_bayes_ab_to_db( + experiment, + user_db.user_id, + workspace_db.workspace_id, + asession + ) + notifications = await save_notifications_to_db( experiment_id=bayes_ab.experiment_id, user_id=user_db.user_id, @@ -66,9 +86,15 @@ async def get_bayes_abs( asession: AsyncSession = Depends(get_async_session), ) -> list[BayesianABResponse]: """ - Get details of all experiments. + Get details of all experiments in the user's current workspace. """ - experiments = await get_all_bayes_ab_experiments(user_db.user_id, asession) + workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) + + experiments = await get_all_bayes_ab_experiments( + user_db.user_id, + workspace_db.workspace_id, + asession + ) all_experiments = [] for exp in experiments: @@ -101,8 +127,13 @@ async def get_bayes_ab( """ Get details of experiment with the provided `experiment_id`. """ + workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) + experiment = await get_bayes_ab_experiment_by_id( - experiment_id, user_db.user_id, asession + experiment_id, + user_db.user_id, + workspace_db.workspace_id, + asession ) if experiment is None: @@ -131,14 +162,36 @@ async def delete_bayes_ab( Delete the experiment with the provided `experiment_id`. 
""" try: + workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) + + user_role = await get_user_role_in_workspace( + asession=asession, user_db=user_db, workspace_db=workspace_db + ) + + if user_role != UserRoles.ADMIN: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Only workspace administrators can delete experiments.", + ) + experiment = await get_bayes_ab_experiment_by_id( - experiment_id, user_db.user_id, asession + experiment_id, + user_db.user_id, + workspace_db.workspace_id, + asession ) if experiment is None: raise HTTPException( status_code=404, detail=f"Experiment with id {experiment_id} not found" ) - await delete_bayes_ab_experiment_by_id(experiment_id, user_db.user_id, asession) + + await delete_bayes_ab_experiment_by_id( + experiment_id, + user_db.user_id, + workspace_db.workspace_id, + asession + ) + return {"message": f"Experiment with id {experiment_id} deleted successfully."} except Exception as e: raise HTTPException(status_code=500, detail=f"Error: {e}") from e @@ -154,8 +207,13 @@ async def draw_arm( """ Get which arm to pull next for provided experiment. """ + workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) + experiment = await get_bayes_ab_experiment_by_id( - experiment_id, user_db.user_id, asession + experiment_id, + user_db.user_id, + workspace_db.workspace_id, + asession ) if experiment is None: @@ -214,9 +272,15 @@ async def save_observation_for_arm( Update the arm with the provided `arm_id` for the given `experiment_id` based on the `outcome`. 
""" + workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) + # Get and validate experiment experiment, draw = await validate_experiment_and_draw( - experiment_id, draw_id, user_db.user_id, asession + experiment_id=experiment_id, + draw_id=draw_id, + user_id=user_db.user_id, + workspace_id=workspace_db.workspace_id, + asession=asession, ) update_experiment_metadata(experiment) @@ -246,8 +310,13 @@ async def get_outcomes( """ Get the outcomes for the experiment. """ + workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) + experiment = await get_bayes_ab_experiment_by_id( - experiment_id, user_db.user_id, asession + experiment_id, + user_db.user_id, + workspace_db.workspace_id, + asession ) if not experiment: raise HTTPException( @@ -275,9 +344,14 @@ async def update_arms( """ Get the outcomes for the experiment. """ + workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) + # Check experiment params experiment = await get_bayes_ab_experiment_by_id( - experiment_id, user_db.user_id, asession + experiment_id, + user_db.user_id, + workspace_db.workspace_id, + asession ) if not experiment: raise HTTPException( @@ -312,10 +386,19 @@ async def update_arms( async def validate_experiment_and_draw( - experiment_id: int, draw_id: str, user_id: int, asession: AsyncSession + experiment_id: int, + draw_id: str, + user_id: int, + workspace_id: int, + asession: AsyncSession, ) -> tuple[BayesianABDB, BayesianABDrawDB]: """Validate the experiment and draw""" - experiment = await get_bayes_ab_experiment_by_id(experiment_id, user_id, asession) + experiment = await get_bayes_ab_experiment_by_id( + experiment_id, + user_id, + workspace_id, + asession + ) if experiment is None: raise HTTPException( status_code=404, detail=f"Experiment with id {experiment_id} not found" diff --git a/backend/app/bayes_ab/schemas.py b/backend/app/bayes_ab/schemas.py index 35183fa..b0dba26 100644 --- 
a/backend/app/bayes_ab/schemas.py +++ b/backend/app/bayes_ab/schemas.py @@ -108,6 +108,7 @@ class BayesianABResponse(MultiArmedBanditBase): """ experiment_id: int + workspace_id: int arms: list[BayesABArmResponse] notifications: list[NotificationsResponse] created_datetime_utc: datetime diff --git a/backend/app/contextual_mab/routers.py b/backend/app/contextual_mab/routers.py index 48a9fa2..46f67d7 100644 --- a/backend/app/contextual_mab/routers.py +++ b/backend/app/contextual_mab/routers.py @@ -12,9 +12,9 @@ from ..models import get_notifications_from_db, save_notifications_to_db from ..schemas import ContextType, NotificationsResponse, Outcome, RewardLikelihood from ..users.models import UserDB +from ..utils import setup_logger from ..workspaces.models import get_user_default_workspace, get_user_role_in_workspace from ..workspaces.schemas import UserRoles -from ..utils import setup_logger from .models import ( ContextualArmDB, ContextualBanditDB, diff --git a/backend/app/mab/routers.py b/backend/app/mab/routers.py index 2adf5a2..6adf12a 100644 --- a/backend/app/mab/routers.py +++ b/backend/app/mab/routers.py @@ -11,9 +11,9 @@ from ..models import get_notifications_from_db, save_notifications_to_db from ..schemas import NotificationsResponse, Outcome, RewardLikelihood from ..users.models import UserDB +from ..utils import setup_logger from ..workspaces.models import get_user_default_workspace, get_user_role_in_workspace from ..workspaces.schemas import UserRoles -from ..utils import setup_logger from .models import ( MABArmDB, MABDrawDB, diff --git a/backend/app/users/exceptions.py b/backend/app/users/exceptions.py new file mode 100644 index 0000000..8be0490 --- /dev/null +++ b/backend/app/users/exceptions.py @@ -0,0 +1,6 @@ +class UserNotFoundError(Exception): + """Exception raised when a user is not found in the database.""" + + +class UserAlreadyExistsError(Exception): + """Exception raised when a user already exists in the database.""" \ No newline at end of 
file diff --git a/backend/app/users/models.py b/backend/app/users/models.py index 2981195..c5912b9 100644 --- a/backend/app/users/models.py +++ b/backend/app/users/models.py @@ -15,18 +15,11 @@ from ..utils import get_key_hash, get_password_salted_hash, get_random_string from ..workspaces.models import UserWorkspaceDB, WorkspaceDB from .schemas import UserCreate, UserCreateWithPassword +from ..users.exceptions import UserAlreadyExistsError, UserNotFoundError PASSWORD_LENGTH = 12 -class UserNotFoundError(Exception): - """Exception raised when a user is not found in the database.""" - - -class UserAlreadyExistsError(Exception): - """Exception raised when a user already exists in the database.""" - - class UserDB(Base): """ SQL Alchemy data model for users diff --git a/backend/app/users/routers.py b/backend/app/users/routers.py index 80cd5e3..443dde3 100644 --- a/backend/app/users/routers.py +++ b/backend/app/users/routers.py @@ -12,11 +12,11 @@ from ..database import get_async_session, get_redis from ..email import EmailService from ..users.models import ( - UserAlreadyExistsError, UserDB, save_user_to_db, update_user_api_key, ) +from ..users.exceptions import UserAlreadyExistsError from ..utils import generate_key, setup_logger, update_api_limits from .schemas import KeyResponse, UserCreate, UserCreateWithPassword, UserRetrieve @@ -45,8 +45,16 @@ async def create_user( """ try: # Import workspace functionality to avoid circular imports - from ..workspaces.models import UserRoles, create_user_workspace_role - from ..workspaces.utils import create_workspace + from ..workspaces.models import ( + UserRoles, + create_user_workspace_role, + get_pending_invitations_by_email, + delete_pending_invitation + ) + from ..workspaces.utils import ( + create_workspace, + get_workspace_by_workspace_id + ) # Create the user new_api_key = generate_key() @@ -69,6 +77,8 @@ async def create_user( user=UserCreate( role=UserRoles.ADMIN, username=user_new.username, + 
first_name=user_new.first_name, + last_name=user_new.last_name, workspace_name=default_workspace_name, ), is_default=True, @@ -84,6 +94,29 @@ async def create_user( workspace_db=workspace_db, ) + # Check for pending invitations + pending_invitations = await get_pending_invitations_by_email( + asession=asession, email=user_new.username + ) + + # Process pending invitations + for invitation in pending_invitations: + invite_workspace = await get_workspace_by_workspace_id( + asession=asession, workspace_id=invitation.workspace_id + ) + + # Add user to the invited workspace + await create_user_workspace_role( + asession=asession, + is_default_workspace=False, + user_db=user_new, + user_role=invitation.role, + workspace_db=invite_workspace, + ) + + # Delete the invitation + await delete_pending_invitation(asession=asession, invitation=invitation) + # Send verification email token = await generate_verification_token( user_new.user_id, user_new.username, redis diff --git a/backend/app/workspaces/models.py b/backend/app/workspaces/models.py index 9547917..34dec0a 100644 --- a/backend/app/workspaces/models.py +++ b/backend/app/workspaces/models.py @@ -8,6 +8,7 @@ ForeignKey, Integer, String, + and_, case, exists, select, @@ -16,6 +17,8 @@ ) from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.exc import NoResultFound +from ..users.exceptions import UserNotFoundError from ..models import Base, ExperimentBaseDB from ..users.schemas import UserCreate @@ -75,6 +78,10 @@ class WorkspaceDB(Base): "ExperimentBaseDB", back_populates="workspace", cascade="all, delete-orphan" ) + pending_invitations: Mapped[list["PendingInvitationDB"]] = relationship( + "PendingInvitationDB", back_populates="workspace", cascade="all, delete-orphan" + ) + def __repr__(self) -> str: """Define the string representation for the `WorkspaceDB` class.""" return f"" @@ -117,6 +124,118 @@ def __repr__(self) -> str: return f"." 
+class PendingInvitationDB(Base): + """ORM for managing pending workspace invitations.""" + + __tablename__ = "pending_invitations" + + invitation_id: Mapped[int] = mapped_column(Integer, primary_key=True) + email: Mapped[str] = mapped_column(String, nullable=False) + workspace_id: Mapped[int] = mapped_column( + Integer, ForeignKey("workspace.workspace_id", ondelete="CASCADE"), nullable=False + ) + role: Mapped[UserRoles] = mapped_column( + Enum(UserRoles, native_enum=False), nullable=False + ) + inviter_id: Mapped[int] = mapped_column( + Integer, ForeignKey("users.user_id"), nullable=False + ) + created_datetime_utc: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False + ) + + workspace: Mapped["WorkspaceDB"] = relationship( + "WorkspaceDB", back_populates="pending_invitations" + ) + + def __repr__(self) -> str: + return f"" + + +async def get_users_in_workspace( + *, asession: AsyncSession, workspace_db: WorkspaceDB +) -> Sequence[UserWorkspaceDB]: + """Get all users in a workspace with their roles.""" + stmt = ( + select(UserWorkspaceDB) + .where(UserWorkspaceDB.workspace_id == workspace_db.workspace_id) + ) + result = await asession.execute(stmt) + return result.unique().scalars().all() + +async def get_user_by_user_id( + user_id: int, asession: AsyncSession +) -> "UserDB": + """Get a user by user ID.""" + stmt = select(UserDB).where(UserDB.user_id == user_id) + result = await asession.execute(stmt) + try: + return result.scalar_one() + except NoResultFound as e: + raise UserNotFoundError(f"User with ID {user_id} not found") from e + +async def remove_user_from_workspace( + *, asession: AsyncSession, user_db: "UserDB", workspace_db: WorkspaceDB +) -> None: + """Remove a user from a workspace.""" + # Check if user exists in workspace + stmt = select(UserWorkspaceDB).where( + and_( + UserWorkspaceDB.user_id == user_db.user_id, + UserWorkspaceDB.workspace_id == workspace_db.workspace_id, + ) + ) + result = await asession.execute(stmt) + 
user_workspace = result.scalar_one_or_none() + + if not user_workspace: + raise UserNotFoundInWorkspaceError( + f"User '{user_db.username}' not found in workspace '{workspace_db.workspace_name}'." + ) + + # Delete the relationship + await asession.delete(user_workspace) + await asession.commit() + +async def create_pending_invitation( + *, + asession: AsyncSession, + email: str, + workspace_db: WorkspaceDB, + role: UserRoles, + inviter_id: int, +) -> PendingInvitationDB: + """Create a pending invitation.""" + invitation = PendingInvitationDB( + email=email, + workspace_id=workspace_db.workspace_id, + role=role, + inviter_id=inviter_id, + created_datetime_utc=datetime.now(timezone.utc), + ) + + asession.add(invitation) + await asession.commit() + await asession.refresh(invitation) + + return invitation + +async def get_pending_invitations_by_email( + *, asession: AsyncSession, email: str +) -> Sequence[PendingInvitationDB]: + """Get all pending invitations for an email.""" + stmt = select(PendingInvitationDB).where(PendingInvitationDB.email == email) + result = await asession.execute(stmt) + return result.scalars().all() + +async def delete_pending_invitation( + *, asession: AsyncSession, invitation: PendingInvitationDB +) -> None: + """Delete a pending invitation.""" + await asession.delete(invitation) + await asession.commit() + + async def check_if_user_has_default_workspace( *, asession: AsyncSession, user_db: "UserDB" ) -> bool | None: diff --git a/backend/app/workspaces/routers.py b/backend/app/workspaces/routers.py index 2cb8800..2cbbf63 100644 --- a/backend/app/workspaces/routers.py +++ b/backend/app/workspaces/routers.py @@ -6,9 +6,12 @@ from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession +from ..users.exceptions import UserNotFoundError + from ..auth.dependencies import ( create_access_token, get_current_user, + get_verified_user, ) from ..auth.schemas import AuthenticationDetails from ..config import 
DEFAULT_API_QUOTA, DEFAULT_EXPERIMENTS_QUOTA @@ -16,17 +19,21 @@ from ..email import EmailService from ..users.models import ( UserDB, - UserNotFoundError, get_user_by_username, ) -from ..users.schemas import UserCreate +from ..users.schemas import MessageResponse, UserCreate from ..utils import generate_key, setup_logger from .models import ( + UserNotFoundInWorkspaceError, add_existing_user_to_workspace, check_if_user_has_default_workspace, + create_pending_invitation, + get_user_by_user_id, get_user_default_workspace, get_user_role_in_workspace, get_user_workspaces, + get_users_in_workspace, + remove_user_from_workspace, update_user_default_workspace, ) from .schemas import ( @@ -38,6 +45,7 @@ WorkspaceRetrieve, WorkspaceSwitch, WorkspaceUpdate, + WorkspaceUserResponse, ) from .utils import ( WorkspaceNotFoundError, @@ -93,6 +101,8 @@ async def create_workspace_endpoint( user=UserCreate( role=UserRoles.ADMIN, username=calling_user_db.username, + first_name=calling_user_db.first_name, + last_name=calling_user_db.last_name, workspace_name=workspace.workspace_name, ), api_key=api_key, @@ -106,6 +116,8 @@ async def create_workspace_endpoint( is_default_workspace=False, # Don't make it default automatically role=UserRoles.ADMIN, username=calling_user_db.username, + first_name=calling_user_db.first_name, + last_name=calling_user_db.last_name, workspace_name=workspace_db.workspace_name, ), workspace_db=workspace_db, @@ -428,6 +440,8 @@ async def invite_user_to_workspace( user=UserCreate( role=invite.role, username=invite.email, + first_name=invited_user.first_name, + last_name=invited_user.last_name, workspace_name=invite.workspace_name, ), workspace_db=workspace_db, @@ -451,7 +465,16 @@ async def invite_user_to_workspace( ) except UserNotFoundError: - # User doesn't exist, send invitation to create account + # User doesn't exist, create pending invitation + await create_pending_invitation( + asession=asession, + email=invite.email, + workspace_db=workspace_db, + 
role=invite.role, + inviter_id=calling_user_db.user_id, + ) + + # Send invitation email background_tasks.add_task( email_service.send_workspace_invitation_email, invite.email, @@ -479,3 +502,113 @@ async def invite_user_to_workspace( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Error inviting user to workspace.", ) from e + + +@router.get("/{workspace_id}/users", response_model=list[WorkspaceUserResponse]) +async def get_workspace_users( + workspace_id: int, + calling_user_db: Annotated[UserDB, Depends(get_verified_user)], + asession: AsyncSession = Depends(get_async_session), +) -> list[WorkspaceUserResponse]: + """Get all users in a workspace.""" + try: + workspace_db = await get_workspace_by_workspace_id( + asession=asession, workspace_id=workspace_id + ) + + user_role = await get_user_role_in_workspace( + asession=asession, user_db=calling_user_db, workspace_db=workspace_db + ) + + if user_role is None: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"User does not have access to workspace with ID {workspace_id}.", + ) + + user_workspaces = await get_users_in_workspace( + asession=asession, workspace_db=workspace_db + ) + + result = [] + for uw in user_workspaces: + user = await get_user_by_user_id(uw.user_id, asession) + result.append( + WorkspaceUserResponse( + user_id=user.user_id, + username=user.username, + first_name=user.first_name, + last_name=user.last_name, + role=uw.user_role, + is_default_workspace=uw.default_workspace, + created_datetime_utc=uw.created_datetime_utc, + ) + ) + + return result + except WorkspaceNotFoundError as e: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Workspace with ID {workspace_id} not found.", + ) from e + +@router.delete("/{workspace_id}/users/{username}", response_model=MessageResponse) +async def remove_user_from_workspace_endpoint( + workspace_id: int, + username: str, + calling_user_db: Annotated[UserDB, Depends(get_verified_user)], + asession: 
AsyncSession = Depends(get_async_session), +) -> MessageResponse: + """Remove a user from a workspace.""" + try: + workspace_db = await get_workspace_by_workspace_id( + asession=asession, workspace_id=workspace_id + ) + + caller_role = await get_user_role_in_workspace( + asession=asession, user_db=calling_user_db, workspace_db=workspace_db + ) + + if caller_role != UserRoles.ADMIN: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Only workspace administrators can remove users.", + ) + + if workspace_db.is_default: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Cannot remove users from default workspaces.", + ) + + if username == calling_user_db.username: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You cannot remove yourself from a workspace.", + ) + + try: + user_to_remove = await get_user_by_username(username=username, asession=asession) + except UserNotFoundError: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"User with username '{username}' not found.", + ) + + try: + await remove_user_from_workspace( + asession=asession, user_db=user_to_remove, workspace_db=workspace_db + ) + return MessageResponse( + message=f"User '{username}' successfully removed from workspace '{workspace_db.workspace_name}'." 
+ ) + except UserNotFoundInWorkspaceError: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"User '{username}' is not a member of workspace '{workspace_db.workspace_name}'.", + ) + except WorkspaceNotFoundError: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Workspace with ID {workspace_id} not found.", + ) diff --git a/backend/app/workspaces/schemas.py b/backend/app/workspaces/schemas.py index 512a78b..2e08fcd 100644 --- a/backend/app/workspaces/schemas.py +++ b/backend/app/workspaces/schemas.py @@ -118,3 +118,17 @@ class WorkspaceInviteResponse(BaseModel): user_exists: bool model_config = ConfigDict(from_attributes=True) + + +class WorkspaceUserResponse(BaseModel): + """Pydantic model for workspace user information.""" + + user_id: int + username: str + first_name: str + last_name: str + role: UserRoles + is_default_workspace: bool + created_datetime_utc: datetime + + model_config = ConfigDict(from_attributes=True) diff --git a/backend/migrations/versions/949c9fc0461d_workspace_relationship.py b/backend/migrations/versions/949c9fc0461d_workspace_relationship.py deleted file mode 100644 index d04edf2..0000000 --- a/backend/migrations/versions/949c9fc0461d_workspace_relationship.py +++ /dev/null @@ -1,36 +0,0 @@ -"""Workspace relationship - -Revision ID: 949c9fc0461d -Revises: 977e7e73ce06 -Create Date: 2025-04-21 21:20:56.282928 - -""" - -from typing import Sequence, Union - -import sqlalchemy as sa -from alembic import op - -# revision identifiers, used by Alembic. -revision: str = "949c9fc0461d" -down_revision: Union[str, None] = "977e7e73ce06" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! 
### - op.add_column( - "experiments_base", sa.Column("workspace_id", sa.Integer(), nullable=False) - ) - op.create_foreign_key( - None, "experiments_base", "workspace", ["workspace_id"], ["workspace_id"] - ) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.drop_constraint(None, "experiments_base", type_="foreignkey") - op.drop_column("experiments_base", "workspace_id") - # ### end Alembic commands ### diff --git a/backend/migrations/versions/977e7e73ce06_workspace_model.py b/backend/migrations/versions/977e7e73ce06_workspace_model.py deleted file mode 100644 index bb4ff38..0000000 --- a/backend/migrations/versions/977e7e73ce06_workspace_model.py +++ /dev/null @@ -1,71 +0,0 @@ -"""Workspace model - -Revision ID: 977e7e73ce06 -Revises: ba1bf29910f5 -Create Date: 2025-04-20 20:17:32.839934 - -""" - -from typing import Sequence, Union - -import sqlalchemy as sa -from alembic import op - -# revision identifiers, used by Alembic. -revision: str = "977e7e73ce06" -down_revision: Union[str, None] = "ba1bf29910f5" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! 
### - op.create_table( - "workspace", - sa.Column("api_daily_quota", sa.Integer(), nullable=True), - sa.Column("api_key_first_characters", sa.String(length=5), nullable=True), - sa.Column( - "api_key_updated_datetime_utc", sa.DateTime(timezone=True), nullable=True - ), - sa.Column("content_quota", sa.Integer(), nullable=True), - sa.Column("created_datetime_utc", sa.DateTime(timezone=True), nullable=False), - sa.Column("hashed_api_key", sa.String(length=96), nullable=True), - sa.Column("updated_datetime_utc", sa.DateTime(timezone=True), nullable=False), - sa.Column("workspace_id", sa.Integer(), nullable=False), - sa.Column("workspace_name", sa.String(), nullable=False), - sa.Column("is_default", sa.Boolean(), nullable=False), - sa.PrimaryKeyConstraint("workspace_id"), - sa.UniqueConstraint("hashed_api_key"), - sa.UniqueConstraint("workspace_name"), - ) - op.create_table( - "user_workspace", - sa.Column("created_datetime_utc", sa.DateTime(timezone=True), nullable=False), - sa.Column( - "default_workspace", - sa.Boolean(), - server_default=sa.text("false"), - nullable=False, - ), - sa.Column("updated_datetime_utc", sa.DateTime(timezone=True), nullable=False), - sa.Column("user_id", sa.Integer(), nullable=False), - sa.Column( - "user_role", - sa.Enum("ADMIN", "READ_ONLY", name="userroles", native_enum=False), - nullable=False, - ), - sa.Column("workspace_id", sa.Integer(), nullable=False), - sa.ForeignKeyConstraint(["user_id"], ["users.user_id"], ondelete="CASCADE"), - sa.ForeignKeyConstraint( - ["workspace_id"], ["workspace.workspace_id"], ondelete="CASCADE" - ), - sa.PrimaryKeyConstraint("user_id", "workspace_id"), - ) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! 
### - op.drop_table("user_workspace") - op.drop_table("workspace") - # ### end Alembic commands ### diff --git a/backend/migrations/versions/d9f7a309944e_workspace_model.py b/backend/migrations/versions/d9f7a309944e_workspace_model.py new file mode 100644 index 0000000..1482764 --- /dev/null +++ b/backend/migrations/versions/d9f7a309944e_workspace_model.py @@ -0,0 +1,72 @@ +"""Workspace model + +Revision ID: d9f7a309944e +Revises: 5c15463fda65 +Create Date: 2025-04-30 23:23:21.122138 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'd9f7a309944e' +down_revision: Union[str, None] = '5c15463fda65' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('workspace', + sa.Column('api_daily_quota', sa.Integer(), nullable=True), + sa.Column('api_key_first_characters', sa.String(length=5), nullable=True), + sa.Column('api_key_updated_datetime_utc', sa.DateTime(timezone=True), nullable=True), + sa.Column('content_quota', sa.Integer(), nullable=True), + sa.Column('created_datetime_utc', sa.DateTime(timezone=True), nullable=False), + sa.Column('hashed_api_key', sa.String(length=96), nullable=True), + sa.Column('updated_datetime_utc', sa.DateTime(timezone=True), nullable=False), + sa.Column('workspace_id', sa.Integer(), nullable=False), + sa.Column('workspace_name', sa.String(), nullable=False), + sa.Column('is_default', sa.Boolean(), nullable=False), + sa.PrimaryKeyConstraint('workspace_id'), + sa.UniqueConstraint('hashed_api_key'), + sa.UniqueConstraint('workspace_name') + ) + op.create_table('pending_invitations', + sa.Column('invitation_id', sa.Integer(), nullable=False), + sa.Column('email', sa.String(), nullable=False), + sa.Column('workspace_id', sa.Integer(), nullable=False), + sa.Column('role', 
sa.Enum('ADMIN', 'READ_ONLY', name='userroles', native_enum=False), nullable=False), + sa.Column('inviter_id', sa.Integer(), nullable=False), + sa.Column('created_datetime_utc', sa.DateTime(timezone=True), nullable=False), + sa.ForeignKeyConstraint(['inviter_id'], ['users.user_id'], ), + sa.ForeignKeyConstraint(['workspace_id'], ['workspace.workspace_id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('invitation_id') + ) + op.create_table('user_workspace', + sa.Column('created_datetime_utc', sa.DateTime(timezone=True), nullable=False), + sa.Column('default_workspace', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('updated_datetime_utc', sa.DateTime(timezone=True), nullable=False), + sa.Column('user_id', sa.Integer(), nullable=False), + sa.Column('user_role', sa.Enum('ADMIN', 'READ_ONLY', name='userroles', native_enum=False), nullable=False), + sa.Column('workspace_id', sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(['user_id'], ['users.user_id'], ondelete='CASCADE'), + sa.ForeignKeyConstraint(['workspace_id'], ['workspace.workspace_id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('user_id', 'workspace_id') + ) + op.add_column('experiments_base', sa.Column('workspace_id', sa.Integer(), nullable=False)) + op.create_foreign_key(None, 'experiments_base', 'workspace', ['workspace_id'], ['workspace_id']) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_constraint(None, 'experiments_base', type_='foreignkey') + op.drop_column('experiments_base', 'workspace_id') + op.drop_table('user_workspace') + op.drop_table('pending_invitations') + op.drop_table('workspace') + # ### end Alembic commands ### From 64768d6af15ca7bb6aae7444f22d9ce1e76a7dcd Mon Sep 17 00:00:00 2001 From: Jay Prakash <0freerunning@gmail.com> Date: Thu, 1 May 2025 13:08:56 +0300 Subject: [PATCH 07/74] New frontend --- backend/app/users/routers.py | 31 - backend/app/workspaces/models.py | 1 + .../src/app/(protected)/integration/api.ts | 25 - .../integration/components/ApiKeyDisplay.tsx | 135 ---- .../src/app/(protected)/integration/page.tsx | 26 - .../app/(protected)/workspace/create/page.tsx | 131 ---- .../app/(protected)/workspace/invite/page.tsx | 193 ------ .../src/app/(protected)/workspace/page.tsx | 116 ---- .../src/app/(protected)/workspace/types.ts | 50 -- .../workspaces/[workspaceId]/page.tsx | 589 ++++++++++++++++++ .../[workspaceId]/users/invite/page.tsx | 244 ++++++++ .../src/app/(protected)/workspaces/page.tsx | 243 ++++++++ .../src/app/(protected)/workspaces/types.ts | 22 + frontend/src/components/WorkspaceSelector.tsx | 142 ----- frontend/src/components/app-sidebar.tsx | 121 +--- .../components/create-workspace-dialog.tsx | 152 +++++ frontend/src/components/ui/alert-dialog.tsx | 143 +++++ frontend/src/components/ui/tabs.tsx | 55 ++ .../src/components/workspace-switcher.tsx | 65 +- frontend/src/utils/api.ts | 186 ++++-- frontend/src/utils/auth.tsx | 183 ++++-- 21 files changed, 1778 insertions(+), 1075 deletions(-) delete mode 100644 frontend/src/app/(protected)/integration/api.ts delete mode 100644 frontend/src/app/(protected)/integration/components/ApiKeyDisplay.tsx delete mode 100644 frontend/src/app/(protected)/integration/page.tsx delete mode 100644 frontend/src/app/(protected)/workspace/create/page.tsx delete mode 100644 frontend/src/app/(protected)/workspace/invite/page.tsx delete mode 100644 
frontend/src/app/(protected)/workspace/page.tsx delete mode 100644 frontend/src/app/(protected)/workspace/types.ts create mode 100644 frontend/src/app/(protected)/workspaces/[workspaceId]/page.tsx create mode 100644 frontend/src/app/(protected)/workspaces/[workspaceId]/users/invite/page.tsx create mode 100644 frontend/src/app/(protected)/workspaces/page.tsx create mode 100644 frontend/src/app/(protected)/workspaces/types.ts delete mode 100644 frontend/src/components/WorkspaceSelector.tsx create mode 100644 frontend/src/components/create-workspace-dialog.tsx create mode 100644 frontend/src/components/ui/alert-dialog.tsx create mode 100644 frontend/src/components/ui/tabs.tsx diff --git a/backend/app/users/routers.py b/backend/app/users/routers.py index 443dde3..5f05144 100644 --- a/backend/app/users/routers.py +++ b/backend/app/users/routers.py @@ -153,34 +153,3 @@ async def get_user( return UserRetrieve.model_validate(user_db) - -@router.put("/rotate-key", response_model=KeyResponse) -async def get_new_api_key( - user_db: Annotated[UserDB, Depends(get_verified_user)], - asession: AsyncSession = Depends(get_async_session), -) -> KeyResponse | None: - """ - Generate a new API key for the requester's account. Takes a user object, - generates a new key, replaces the old one in the database, and returns - a user object with the new key. 
- """ - - new_api_key = generate_key() - - try: - # this is neccesarry to attach the user_db to the session - asession.add(user_db) - await update_user_api_key( - user_db=user_db, - new_api_key=new_api_key, - asession=asession, - ) - return KeyResponse( - username=user_db.username, - new_api_key=new_api_key, - ) - except SQLAlchemyError as e: - logger.error(f"Error updating user api key: {e}") - raise HTTPException( - status_code=500, detail="Error updating user api key" - ) from e diff --git a/backend/app/workspaces/models.py b/backend/app/workspaces/models.py index 34dec0a..23d601c 100644 --- a/backend/app/workspaces/models.py +++ b/backend/app/workspaces/models.py @@ -167,6 +167,7 @@ async def get_user_by_user_id( user_id: int, asession: AsyncSession ) -> "UserDB": """Get a user by user ID.""" + from ..users.models import UserDB stmt = select(UserDB).where(UserDB.user_id == user_id) result = await asession.execute(stmt) try: diff --git a/frontend/src/app/(protected)/integration/api.ts b/frontend/src/app/(protected)/integration/api.ts deleted file mode 100644 index 47fe784..0000000 --- a/frontend/src/app/(protected)/integration/api.ts +++ /dev/null @@ -1,25 +0,0 @@ -import api from "@/utils/api"; - -const getUser = async (token: string | null) => { - const response = await api.get("/user/", { - headers: { - Authorization: `Bearer ${token}`, - }, - }); - return response.data; -}; - -const rotateAPIKey = async (token: string | null) => { - const response = await api.put( - "/user/rotate-key", - {}, - { - headers: { - Authorization: `Bearer ${token}`, - }, - } - ); - return response.data; -}; - -export { getUser, rotateAPIKey }; diff --git a/frontend/src/app/(protected)/integration/components/ApiKeyDisplay.tsx b/frontend/src/app/(protected)/integration/components/ApiKeyDisplay.tsx deleted file mode 100644 index 5bc03da..0000000 --- a/frontend/src/app/(protected)/integration/components/ApiKeyDisplay.tsx +++ /dev/null @@ -1,135 +0,0 @@ -"use client"; - -import { 
useState } from "react"; -import { Button } from "@/components/ui/button"; -import { - Dialog, - DialogContent, - DialogDescription, - DialogTitle, -} from "@/components/ui/dialog"; -import { KeyRound, Copy } from "lucide-react"; -import { useToast } from "@/hooks/use-toast"; -import { useAuth } from "@/utils/auth"; -import { useEffect } from "react"; -import Hourglass from "@/components/Hourglass"; -import { getUser, rotateAPIKey } from "../api"; - -export function ApiKeyDisplay() { - const { token } = useAuth(); - - const [apiKey, setApiKey] = useState(""); - const [isModalOpen, setIsModalOpen] = useState(false); - const [newKey, setNewKey] = useState(""); - const [isRefreshing, setIsRefreshing] = useState(false); - const [isLoading, setIsLoading] = useState(false); - - const { toast } = useToast(); - - useEffect(() => { - setIsLoading(true); - getUser(token) - .then((data) => { - setApiKey(data.api_key_first_characters); - }) - .catch((error: Error) => { - console.log(error); - }) - .finally(() => { - setIsLoading(false); - }); - }, [token]); - - const handleGenerateKey = async () => { - setIsRefreshing(true); - rotateAPIKey(token) - .then((data) => { - setNewKey(data.new_api_key); - setIsModalOpen(true); - }) - .catch((error: Error) => { - console.log(error); - toast({ - title: "Error", - description: "Failed to generate new API key", - variant: "destructive", - }); - }) - .finally(() => { - setIsRefreshing(false); - }); - }; - - const handleCopyKey = async () => { - try { - await navigator.clipboard.writeText(newKey); - toast({ - title: "Copied!", - description: "API key copied to clipboard", - }); - } catch (error) { - toast({ - title: "Error", - description: "Failed to copy API key", - variant: "destructive", - }); - } - }; - - const handleConfirm = () => { - setApiKey(newKey); - setIsModalOpen(false); - toast({ - title: "Success", - description: "API key has been updated", - }); - }; - - return isLoading ? ( -
-
- - Loading... -
-
- ) : ( -
-
-
- - - {apiKey.slice(0, 5)} - {"•".repeat(27)} - -
- -
- - - - New API Key Generated - - Make sure to copy your new API key. You won't be able to see it - again! - -
-
- {newKey} -
-
- - -
-
-
-
-
- ); -} diff --git a/frontend/src/app/(protected)/integration/page.tsx b/frontend/src/app/(protected)/integration/page.tsx deleted file mode 100644 index 6922061..0000000 --- a/frontend/src/app/(protected)/integration/page.tsx +++ /dev/null @@ -1,26 +0,0 @@ -import { - Card, - CardContent, - CardDescription, - CardHeader, - CardTitle, -} from "@/components/ui/card"; -import { ApiKeyDisplay } from "./components/ApiKeyDisplay"; - -export default function ApiKeyPage() { - return ( -
- - - API Key Management - - View and manage your API keys securely - - - - - - -
- ); -} diff --git a/frontend/src/app/(protected)/workspace/create/page.tsx b/frontend/src/app/(protected)/workspace/create/page.tsx deleted file mode 100644 index 99aff9e..0000000 --- a/frontend/src/app/(protected)/workspace/create/page.tsx +++ /dev/null @@ -1,131 +0,0 @@ -"use client"; - -import { zodResolver } from "@hookform/resolvers/zod"; -import { useForm } from "react-hook-form"; -import { z } from "zod"; -import { Button } from "@/components/catalyst/button"; -import { Input } from "@/components/catalyst/input"; -import { useAuth } from "@/utils/auth"; -import { apiCalls } from "@/utils/api"; -import { useToast } from "@/hooks/use-toast"; -import { useRouter } from "next/navigation"; -import { useState } from "react"; -import { - Fieldset, - Field, - FieldGroup, - Label, - Description, -} from "@/components/catalyst/fieldset"; -import { Card, CardContent, CardHeader, CardTitle, CardDescription } from "@/components/ui/card"; -import { BuildingOfficeIcon } from "@heroicons/react/20/solid"; - -const formSchema = z.object({ - workspace_name: z.string().min(3, { - message: "Workspace name must be at least 3 characters", - }), -}); - -type FormValues = z.infer; - -export default function CreateWorkspacePage() { - const { token, switchWorkspace } = useAuth(); - const { toast } = useToast(); - const router = useRouter(); - const [isSubmitting, setIsSubmitting] = useState(false); - - const form = useForm({ - resolver: zodResolver(formSchema), - defaultValues: { - workspace_name: "", - }, - }); - - const onSubmit = async (data: FormValues) => { - if (!token) { - toast({ - title: "Error", - description: "You must be logged in to create a workspace", - variant: "destructive", - }); - return; - } - - setIsSubmitting(true); - try { - const response = await apiCalls.createWorkspace(token, data); - - await switchWorkspace(response.workspace_name); - - toast({ - title: "Success", - description: `Workspace "${response.workspace_name}" created and activated!`, - }); - - 
router.push("/workspace"); - } catch (error: any) { - toast({ - title: "Error", - description: error.message || "Failed to create workspace", - variant: "destructive", - }); - } finally { - setIsSubmitting(false); - } - }; - - return ( -
- - -
- -
- Create new workspace - - Create a new workspace to organize your experiments and team members - -
-
-
- -
-
- - - - - {form.formState.errors.workspace_name && ( -

- {form.formState.errors.workspace_name.message} -

- )} - Choose a descriptive name for your new workspace. -
-
- -
- - -
-
-
-
-
-
- ); -} diff --git a/frontend/src/app/(protected)/workspace/invite/page.tsx b/frontend/src/app/(protected)/workspace/invite/page.tsx deleted file mode 100644 index f610ffc..0000000 --- a/frontend/src/app/(protected)/workspace/invite/page.tsx +++ /dev/null @@ -1,193 +0,0 @@ -"use client"; - -import { useState } from "react"; -import { useAuth } from "@/utils/auth"; -import { apiCalls } from "@/utils/api"; -import { useToast } from "@/hooks/use-toast"; -import { z } from "zod"; -import { zodResolver } from "@hookform/resolvers/zod"; -import { useForm } from "react-hook-form"; -import { Button } from "@/components/catalyst/button"; -import { Heading } from "@/components/catalyst/heading"; -import { Input } from "@/components/catalyst/input"; -import { - Card, - CardContent, - CardDescription, - CardFooter, - CardHeader, - CardTitle, -} from "@/components/ui/card"; -import { - Fieldset, - Field, - FieldGroup, - Label, - Description, -} from "@/components/catalyst/fieldset"; -import { Radio, RadioField, RadioGroup } from "@/components/catalyst/radio"; -import { Badge } from "@/components/ui/badge"; -import { EnvelopeIcon, UserPlusIcon } from "@heroicons/react/20/solid"; - -const inviteSchema = z.object({ - email: z.string().email({ - message: "Please enter a valid email address", - }), - role: z.enum(["ADMIN", "EDITOR", "VIEWER"], { - required_error: "Please select a role", - }), -}); - -type InviteFormValues = z.infer; - -export default function InviteUsersPage() { - const { token, currentWorkspace } = useAuth(); - const { toast } = useToast(); - const [isSubmitting, setIsSubmitting] = useState(false); - const [invitedUsers, setInvitedUsers] = useState<{ email: string; role: string; exists: boolean }[]>([]); - - const { - register, - handleSubmit, - reset, - formState: { errors }, - setValue, - watch, - } = useForm({ - resolver: zodResolver(inviteSchema), - defaultValues: { - email: "", - role: "VIEWER", - }, - }); - - const roleValue = watch("role"); - - const 
onSubmit = async (data: InviteFormValues) => { - if (!token || !currentWorkspace) { - toast({ - title: "Error", - description: "You must be logged in and have a workspace selected", - variant: "destructive", - }); - return; - } - - setIsSubmitting(true); - try { - const response = await apiCalls.inviteUserToWorkspace(token, { - email: data.email, - role: data.role, - workspace_name: currentWorkspace.workspace_name, - }); - - // Add to invited users list - setInvitedUsers([ - ...invitedUsers, - { - email: data.email, - role: data.role, - exists: response.user_exists, - }, - ]); - - // Reset form - reset(); - - toast({ - title: "Success", - description: `Invitation sent to ${data.email}`, - }); - } catch (error: any) { - toast({ - title: "Error", - description: error.message || "Failed to send invitation", - variant: "destructive", - }); - } finally { - setIsSubmitting(false); - } - }; - - return ( -
-
- Invite Team Members -
- -
-
- - - Send Invitation - - Invite users to join your workspace: {currentWorkspace?.workspace_name} - - - -
-
- - - - - {errors.email && ( -

- {errors.email.message} -

- )} - Enter the email address of the person you want to invite. -
-
- - - - setValue("role", value as "ADMIN" | "VIEWER")} - > - - - - - Can manage workspace settings, invite members, and has full access to all resources. - - - - - - - Can only view resources, but cannot edit them or change any settings. - - - - {errors.role && ( -

- {errors.role.message} -

- )} -
-
- -
- -
-
-
-
-
-
-
- ); -} diff --git a/frontend/src/app/(protected)/workspace/page.tsx b/frontend/src/app/(protected)/workspace/page.tsx deleted file mode 100644 index 70b4d4c..0000000 --- a/frontend/src/app/(protected)/workspace/page.tsx +++ /dev/null @@ -1,116 +0,0 @@ -"use client"; - -import { useEffect } from "react"; -import { useRouter } from "next/navigation"; -import { useAuth } from "@/utils/auth"; -import { Heading } from "@/components/catalyst/heading"; -import { Card, CardContent, CardHeader, CardTitle, CardDescription } from "@/components/ui/card"; -import { Plus, Users, Settings, Key } from "lucide-react"; -import { Button } from "@/components/catalyst/button"; - -export default function WorkspacePage() { - const { currentWorkspace } = useAuth(); - const router = useRouter(); - - if (!currentWorkspace) { - return ( -
- - - - Something went wrong. Please try again later. - - - -
- ); - } - console.log("Current Workspace:", currentWorkspace); - - return ( -
-
- {currentWorkspace.workspace_name} -
- -
- - - Workspace Information - - -
-
-
Name:
-
{currentWorkspace.workspace_name}
-
-
-
API Quota:
-
{currentWorkspace.api_daily_quota.toLocaleString()} calls/day
-
-
-
Experiment Quota:
-
{currentWorkspace.content_quota.toLocaleString()} experiments
-
-
-
Created:
-
- {new Date(currentWorkspace.created_datetime_utc).toLocaleDateString()} -
-
-
-
-
- - - - API Key - - -
-
- - - {currentWorkspace.api_key_first_characters} - {"•".repeat(27)} - -
- -
-

- Use this API key to authenticate your API requests. Keep it secret and secure. -

-
-
-
- -
- router.push('/workspace/invite')}> - -
- Invite Team Members - - Invite colleagues to join your workspace - -
- -
-
- - router.push('/workspace/create')}> - -
- Create New Workspace - - Create a new workspace for different projects - -
- -
-
-
-
- ); -} diff --git a/frontend/src/app/(protected)/workspace/types.ts b/frontend/src/app/(protected)/workspace/types.ts deleted file mode 100644 index 2d53c31..0000000 --- a/frontend/src/app/(protected)/workspace/types.ts +++ /dev/null @@ -1,50 +0,0 @@ -export enum UserRoles { - ADMIN = "ADMIN", - VIEWER = "VIEWER", -} - -export interface Workspace { - workspace_id: number; - workspace_name: string; - api_key_first_characters: string; - api_key_updated_datetime_utc: string; - api_daily_quota: number; - content_quota: number; - created_datetime_utc: string; - updated_datetime_utc: string; - is_default: boolean; -} - -export interface WorkspaceCreate { - workspace_name: string; - api_daily_quota?: number; - content_quota?: number; -} - -export interface WorkspaceUpdate { - workspace_name?: string; - api_daily_quota?: number; - content_quota?: number; -} - -export interface WorkspaceKeyResponse { - new_api_key: string; - workspace_name: string; -} - -export interface WorkspaceInvite { - email: string; - role: UserRoles; - workspace_name: string; -} - -export interface WorkspaceInviteResponse { - message: string; - email: string; - workspace_name: string; - user_exists: boolean; -} - -export interface WorkspaceSwitch { - workspace_name: string; -} diff --git a/frontend/src/app/(protected)/workspaces/[workspaceId]/page.tsx b/frontend/src/app/(protected)/workspaces/[workspaceId]/page.tsx new file mode 100644 index 0000000..a3d2a96 --- /dev/null +++ b/frontend/src/app/(protected)/workspaces/[workspaceId]/page.tsx @@ -0,0 +1,589 @@ +// Path: frontend/src/app/(protected)/workspaces/[workspaceId]/page.tsx +"use client"; + +import { useState, useEffect } from "react"; +import { useParams, useRouter } from "next/navigation"; +import { useAuth } from "@/utils/auth"; +import { apiCalls } from "@/utils/api"; +import { useToast } from "@/hooks/use-toast"; + +import { Building, ChevronLeftIcon, Users, Key, Copy } from "lucide-react"; +import { Button } from "@/components/ui/button"; 
+import { + Card, + CardContent, + CardDescription, + CardFooter, + CardHeader, + CardTitle, +} from "@/components/ui/card"; +import { + Breadcrumb, + BreadcrumbItem, + BreadcrumbLink, + BreadcrumbList, + BreadcrumbPage, + BreadcrumbSeparator, +} from "@/components/ui/breadcrumb"; +import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs"; +import Hourglass from "@/components/Hourglass"; +import { Separator } from "@/components/ui/separator"; +import { + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, + AlertDialogTrigger, +} from "@/components/ui/alert-dialog"; +import { + Dialog, + DialogContent, + DialogDescription, + DialogFooter, + DialogHeader, + DialogTitle, +} from "@/components/ui/dialog"; + +export default function WorkspaceDetailPage() { + const params = useParams(); + const router = useRouter(); + const { token, currentWorkspace, fetchWorkspaces, switchWorkspace } = + useAuth(); + const { toast } = useToast(); + + const [workspace, setWorkspace] = useState(null); + const [workspaceUsers, setWorkspaceUsers] = useState([]); + const [isLoading, setIsLoading] = useState(true); + const [isRotatingKey, setIsRotatingKey] = useState(false); + const [newApiKey, setNewApiKey] = useState(null); + const [isApiKeyDialogOpen, setIsApiKeyDialogOpen] = useState(false); + + const workspaceId = Number(params.workspaceId); + + useEffect(() => { + const loadWorkspaceData = async () => { + if (!token) return; + + setIsLoading(true); + try { + // Fetch workspace details + const workspaceData = await apiCalls.getWorkspaceById( + token, + workspaceId + ); + setWorkspace(workspaceData); + + // Fetch workspace users + const usersData = await apiCalls.getWorkspaceUsers(token, workspaceId); + setWorkspaceUsers(usersData); + } catch (error) { + console.error("Error loading workspace data:", error); + toast({ + title: "Error", + description: 
"Failed to load workspace details", + variant: "destructive", + }); + } finally { + setIsLoading(false); + } + }; + + loadWorkspaceData(); + }, [token, workspaceId, toast]); + + const handleRotateApiKey = async () => { + if (!token || !workspace) return; + + setIsRotatingKey(true); + try { + // Need to switch to this workspace first if it's not the current one + if (currentWorkspace?.workspace_id !== workspaceId) { + await switchWorkspace(workspace.workspace_name); + } + + const result = await apiCalls.rotateWorkspaceApiKey(token); + setNewApiKey(result.new_api_key); + setIsApiKeyDialogOpen(true); + + // Refresh workspaces data + await fetchWorkspaces(); + + // Refresh workspace details + const updatedWorkspace = await apiCalls.getWorkspaceById( + token, + workspaceId + ); + setWorkspace(updatedWorkspace); + + toast({ + title: "Success", + description: "API key rotated successfully", + variant: "success", + }); + } catch (error: any) { + toast({ + title: "Error", + description: error.message || "Failed to rotate API key", + variant: "destructive", + }); + } finally { + setIsRotatingKey(false); + } + }; + + const handleCopyApiKey = () => { + if (!newApiKey) return; + + navigator.clipboard.writeText(newApiKey); + toast({ + title: "Copied", + description: "API key copied to clipboard", + }); + }; + + const handleRemoveUser = async (username: string) => { + if (!token) return; + + try { + await apiCalls.removeUserFromWorkspace(token, workspaceId, username); + + // Refresh user list + const updatedUsers = await apiCalls.getWorkspaceUsers(token, workspaceId); + setWorkspaceUsers(updatedUsers); + + toast({ + title: "Success", + description: `${username} has been removed from the workspace`, + variant: "success", + }); + } catch (error: any) { + toast({ + title: "Error", + description: error.message || "Failed to remove user", + variant: "destructive", + }); + } + }; + + if (isLoading) { + return ( +
+
+ + + Loading workspace details... + +
+
+ ); + } + + if (!workspace) { + return ( +
+
+

Workspace Not Found

+

+ The requested workspace could not be found +

+ +
+
+ ); + } + + return ( + <> +
+
+ + + + Home + + + + Workspaces + + + + {workspace.workspace_name} + + + +
+ +
+
+
+ +
+
+

+ {workspace.workspace_name} +

+

+ Workspace ID: {workspace.workspace_id} +

+
+
+ +
+ + + + Overview + Users + API + + + + + + Workspace Overview + + Summary information about this workspace + + + +
+
+

+ Workspace Details +

+
+
Name:
+
{workspace.workspace_name}
+ +
ID:
+
{workspace.workspace_id}
+ +
Created:
+
+ {new Date( + workspace.created_datetime_utc + ).toLocaleDateString()} +
+ +
Last Updated:
+
+ {new Date( + workspace.updated_datetime_utc + ).toLocaleDateString()} +
+ +
+ API Daily Quota: +
+
{workspace.api_daily_quota} calls/day
+ +
+ Content Quota: +
+
{workspace.content_quota} experiments
+
+
+ +
+

+ API Configuration +

+
+
+ API Key Prefix: +
+
+ {workspace.api_key_first_characters}••••• +
+ +
+ Key Last Rotated: +
+
+ {new Date( + workspace.api_key_updated_datetime_utc + ).toLocaleDateString()} +
+
+
+
+
+ {currentWorkspace?.workspace_id !== workspace.workspace_id && ( + + + + )} +
+
+ + + + + + Workspace Users + {!workspace.is_default && ( + + )} + + + {workspace.is_default + ? "This is a default workspace. User management is restricted." + : "Manage users who have access to this workspace"} + + + + {workspaceUsers.length === 0 ? ( +
+ +

No users found

+

+ {workspace.is_default + ? "Default workspaces automatically include all users." + : "Invite users to collaborate in this workspace"} +

+
+ ) : ( +
+
+
User
+
Role
+
Joined
+
Actions
+
+ {workspaceUsers.map((user) => { + // Find the current user to determine if they have admin rights + const isCurrentUserAdmin = + workspaceUsers.find( + (u) => u.username === currentWorkspace?.username + )?.role === "admin"; + + return ( +
+
+
+ {user.first_name} {user.last_name} +
+
+ {user.username} +
+
+
+ {user.role.toLowerCase()} + {user.is_default_workspace && ( + + Default + + )} +
+
+ {new Date( + user.created_datetime_utc + ).toLocaleDateString()} +
+
+ {!workspace.is_default && isCurrentUserAdmin && ( + + + + + + + + Remove user? + + + Are you sure you want to remove{" "} + {user.first_name} {user.last_name} from + this workspace? + + + + { + e.preventDefault(); + e.stopPropagation(); + }} + > + Cancel + + + handleRemoveUser(user.username) + } + > + Remove + + + + + )} +
+
+ ); + })} +
+ )} +
+
+
+ + + + + + API Configuration + + + + Manage API settings for this workspace + + + +
+

Current API Key

+
+ {workspace.api_key_first_characters} + ••••••••••••••••••••••••••• +
+

+ Last updated on{" "} + {new Date( + workspace.api_key_updated_datetime_utc + ).toLocaleString()} +

+
+ +
+

API Usage Limits

+
+
+
+ Daily Quota +
+
+ {workspace.api_daily_quota} +
+
+ API calls per day +
+
+
+
+ Content Quota +
+
+ {workspace.content_quota} +
+
+ Experiments +
+
+
+
+ + + +
+

+ About API Key Rotation +

+

+ When you rotate your API key, the old key will be + immediately invalidated. Any services or applications using + the old key will need to be updated with the new key. Make + sure to copy and save your new key as it will only be shown + once. +

+
+
+
+
+
+
+ + {/* Dialog for showing the new API key */} + + + + New API Key Generated + + Make sure to copy your new API key. You won't be able to see it + again! + + + +
+ {newApiKey} +
+ + + + + +
+
+ + ); +} diff --git a/frontend/src/app/(protected)/workspaces/[workspaceId]/users/invite/page.tsx b/frontend/src/app/(protected)/workspaces/[workspaceId]/users/invite/page.tsx new file mode 100644 index 0000000..d2d2644 --- /dev/null +++ b/frontend/src/app/(protected)/workspaces/[workspaceId]/users/invite/page.tsx @@ -0,0 +1,244 @@ +"use client"; + +import { useState } from "react"; +import { useParams, useRouter } from "next/navigation"; +import { useAuth } from "@/utils/auth"; +import { apiCalls } from "@/utils/api"; +import { useToast } from "@/hooks/use-toast"; +import { z } from "zod"; +import { useForm } from "react-hook-form"; +import { zodResolver } from "@hookform/resolvers/zod"; + +import { ChevronLeftIcon, Send, Users } from "lucide-react"; +import { Button } from "@/components/ui/button"; +import { + Card, + CardContent, + CardDescription, + CardFooter, + CardHeader, + CardTitle, +} from "@/components/ui/card"; +import { + Breadcrumb, + BreadcrumbItem, + BreadcrumbLink, + BreadcrumbList, + BreadcrumbPage, + BreadcrumbSeparator, +} from "@/components/ui/breadcrumb"; +import { + Form, + FormControl, + FormDescription, + FormField, + FormItem, + FormLabel, + FormMessage, +} from "@/components/ui/form"; +import { Input } from "@/components/ui/input"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; + +// Create schema for invitation form +const InviteFormSchema = z.object({ + email: z + .string() + .email("Please enter a valid email address"), + role: z + .string() + .refine(val => ["admin", "read_only"].includes(val), { + message: "Please select a valid role", + }), +}); + +type InviteFormValues = z.infer; + +export default function InviteUserPage() { + const params = useParams(); + const router = useRouter(); + const { token } = useAuth(); + const { toast } = useToast(); + const [isInviting, setIsInviting] = useState(false); + const [inviteSuccess, setInviteSuccess] = useState(false); + 
+ const workspaceId = Number(params.workspaceId); + + // Initialize form + const form = useForm({ + resolver: zodResolver(InviteFormSchema), + defaultValues: { + email: "", + role: "read_only", // Default to read-only + }, + }); + + const onSubmit = async (data: InviteFormValues) => { + if (!token) return; + + setIsInviting(true); + try { + // First need to get workspace name + const workspace = await apiCalls.getWorkspaceById(token, workspaceId); + + // Send invitation + const response = await apiCalls.inviteUserToWorkspace( + token, + data.email, + workspace.workspace_name, + data.role + ); + + // Show success message + setInviteSuccess(true); + toast({ + title: "Invitation sent", + description: `${data.email} has been invited to the workspace`, + variant: "success", + }); + + // Reset form + form.reset(); + } catch (error: any) { + toast({ + title: "Error", + description: error.message || "Failed to send invitation", + variant: "destructive", + }); + } finally { + setIsInviting(false); + } + }; + + return ( +
+
+ + + + Home + + + + Workspaces + + + + Workspace + + + + Invite User + + + +
+ +
+
+
+ +
+

Invite User

+
+ +
+ + + + Invite User to Workspace + + Send an invitation to join this workspace + + + +
+ + ( + + Email + + + + + The email address of the person you want to invite + + + + )} + /> + + ( + + Role + + + Administrators can manage workspace settings and users. Regular users have read-only access. + + + + )} + /> + + + + + + {inviteSuccess && ( +
+

+ Invitation Sent Successfully +

+

+ An email has been sent to the user with instructions to join the workspace. +

+
+ )} +
+ + + +
+
+ ); +} \ No newline at end of file diff --git a/frontend/src/app/(protected)/workspaces/page.tsx b/frontend/src/app/(protected)/workspaces/page.tsx new file mode 100644 index 0000000..2e7b6bc --- /dev/null +++ b/frontend/src/app/(protected)/workspaces/page.tsx @@ -0,0 +1,243 @@ +"use client"; + +import React, { useState, useEffect } from "react"; +import { useAuth } from "@/utils/auth"; +import { Building, Plus, Settings } from "lucide-react"; +import { Button } from "@/components/ui/button"; +import { + Card, + CardContent, + CardDescription, + CardFooter, + CardHeader, + CardTitle, +} from "@/components/ui/card"; +import { + Breadcrumb, + BreadcrumbItem, + BreadcrumbLink, + BreadcrumbList, + BreadcrumbPage, + BreadcrumbSeparator, +} from "@/components/ui/breadcrumb"; +import Hourglass from "@/components/Hourglass"; +import { Badge } from "@/components/ui/badge"; +import { Separator } from "@/components/ui/separator"; +import Link from "next/link"; +import { CreateWorkspaceDialog } from "@/components/create-workspace-dialog"; +import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs"; +import router from "next/router"; + +export default function WorkspacesPage() { + const { workspaces, currentWorkspace, fetchWorkspaces, switchWorkspace } = useAuth(); + const [isLoading, setIsLoading] = useState(true); + const [isCreateDialogOpen, setIsCreateDialogOpen] = useState(false); + + useEffect(() => { + const loadWorkspaces = async () => { + setIsLoading(true); + try { + await fetchWorkspaces(); + } catch (error) { + console.error("Error loading workspaces:", error); + } finally { + setIsLoading(false); + } + }; + + loadWorkspaces(); + }, []); + + const handleWorkspaceSwitch = async (workspaceName: string) => { + try { + setIsLoading(true); + await switchWorkspace(workspaceName); + } catch (error) { + console.error("Error switching workspace:", error); + } finally { + setIsLoading(false); + } + }; + + if (isLoading) { + return ( +
+
+ + Loading workspaces... +
+
+ ); + } + + return ( + <> +
+
+ + + + Home + + + + Workspaces + + + +
+ +
+
+

Workspaces

+

+ Manage your workspaces and team members +

+
+ +
+ + + + All Workspaces + Current Workspace + + + +
+ {workspaces.map((workspace) => ( + + +
+ {workspace.is_default && ( + Default + )} + {currentWorkspace?.workspace_id === workspace.workspace_id && ( + Current + )} +
+
+
+ +
+ {workspace.workspace_name} +
+ + ID: {workspace.workspace_id} + +
+ +
+
+ API Quota: + {workspace.api_daily_quota} calls/day +
+
+ Content Quota: + {workspace.content_quota} experiments +
+
+ API Key: + {workspace.api_key_first_characters}••••• +
+
+ Created: + {new Date(workspace.created_datetime_utc).toLocaleDateString()} +
+
+
+ + + + + + + +
+ ))} +
+
+ + + {currentWorkspace && ( +
+ + +
+
+ +
+
+ {currentWorkspace.workspace_name} + + Workspace ID: {currentWorkspace.workspace_id} + +
+
+
+ +
+
+

Workspace Details

+
+
Created:
+
{new Date(currentWorkspace.created_datetime_utc).toLocaleDateString()}
+ +
Last Updated:
+
{new Date(currentWorkspace.updated_datetime_utc).toLocaleDateString()}
+ +
API Daily Quota:
+
{currentWorkspace.api_daily_quota} calls/day
+ +
Content Quota:
+
{currentWorkspace.content_quota} experiments
+
+
+ +
+

API Configuration

+
+
API Key Prefix:
+
{currentWorkspace.api_key_first_characters}•••••
+ +
Key Last Rotated:
+
{new Date(currentWorkspace.api_key_updated_datetime_utc).toLocaleDateString()}
+
+
+
+
+ + + +
+
+ )} +
+
+
+ + + + ); +} diff --git a/frontend/src/app/(protected)/workspaces/types.ts b/frontend/src/app/(protected)/workspaces/types.ts new file mode 100644 index 0000000..6a0c8de --- /dev/null +++ b/frontend/src/app/(protected)/workspaces/types.ts @@ -0,0 +1,22 @@ +export type WorkspaceUser = { + user_id: number; + username: string; + first_name: string; + last_name: string; + role: string; + is_default_workspace: boolean; + created_datetime_utc: string; +}; + +export type Workspace = { + workspace_id: number; + workspace_name: string; + api_key_first_characters: string; + api_daily_quota: number; + content_quota: number; + created_datetime_utc: string; + updated_datetime_utc: string; + api_key_updated_datetime_utc: string; + is_default: boolean; +}; + diff --git a/frontend/src/components/WorkspaceSelector.tsx b/frontend/src/components/WorkspaceSelector.tsx deleted file mode 100644 index 4692888..0000000 --- a/frontend/src/components/WorkspaceSelector.tsx +++ /dev/null @@ -1,142 +0,0 @@ -"use client"; - -import React, { useState } from "react"; -import { useAuth } from "@/utils/auth"; -import { Button } from "@/components/catalyst/button"; -import { - Dialog, - DialogActions, - DialogBody, - DialogDescription, - DialogTitle, -} from "@/components/catalyst/dialog"; -import { BuildingOfficeIcon, ChevronUpDownIcon, PlusIcon } from "@heroicons/react/20/solid"; -import { - DropdownItem, - DropdownLabel, - DropdownMenu, - DropdownButton, - DropdownDivider, - Dropdown, -} from "@/components/catalyst/dropdown"; -import { useToast } from "@/hooks/use-toast"; -import { useRouter } from "next/navigation"; - -export default function WorkspaceSelector() { - const { currentWorkspace, workspaces, switchWorkspace, isLoading } = useAuth(); - const [isOpen, setIsOpen] = useState(false); - const { toast } = useToast(); - const router = useRouter(); - - const handleSwitchWorkspace = async (workspaceName: string) => { - try { - await switchWorkspace(workspaceName); - toast({ - title: 
"Workspace Changed", - description: `Switched to workspace: ${workspaceName}`, - }); - } catch (error) { - toast({ - title: "Error", - description: "Failed to switch workspace", - variant: "destructive", - }); - } - }; - - const handleCreateWorkspace = () => { - router.push("/workspace/create"); - }; - - if (isLoading || !currentWorkspace) { - return ( - - ); - } - - return ( -
- - - - - {currentWorkspace.workspace_name} - - - - - - {workspaces.map((workspace) => ( - handleSwitchWorkspace(workspace.workspace_name)} - > - - {workspace.workspace_name} - - ))} - - - - - - Create New Workspace - - - - - setIsOpen(false)}> - Switch Workspace - - Select a workspace to switch to - - -
- {workspaces.map((workspace) => ( -
{ - handleSwitchWorkspace(workspace.workspace_name); - setIsOpen(false); - }} - > -
{workspace.workspace_name}
- {workspace.workspace_id === currentWorkspace.workspace_id && ( -
Current
- )} -
- ))} -
-
- - - - -
-
- ); -} diff --git a/frontend/src/components/app-sidebar.tsx b/frontend/src/components/app-sidebar.tsx index 6c8ee29..8495693 100644 --- a/frontend/src/components/app-sidebar.tsx +++ b/frontend/src/components/app-sidebar.tsx @@ -1,16 +1,11 @@ "use client"; import * as React from "react"; import { - AudioWaveform, ArrowLeftRightIcon, LayoutDashboardIcon, - Command, - Frame, - GalleryVerticalEnd, - Map, - PieChart, - Settings2, FlaskConicalIcon, + Settings2, + Building, } from "lucide-react"; import { NavMain } from "@/components/nav-main"; import { NavRecentExperiments } from "@/components/nav-recent-experiments"; @@ -23,128 +18,68 @@ import { SidebarHeader, SidebarRail, } from "@/components/ui/sidebar"; -import api from "@/utils/api"; import { useAuth } from "@/utils/auth"; -type UserDetails = { - username: string; - firstName: string; - lastName: string; - isActive: boolean; - isVerified: boolean; -}; - -const getUserDetails = async (token: string | null) => { - try { - const response = await api.get("/user", { - headers: { - Authorization: `Bearer ${token}`, - }, - }); - - return { - username: response.data.username, - firstName: response.data.first_name, - lastName: response.data.last_name, - isActive: response.data.is_active, - isVerified: response.data.is_verified, - } as UserDetails; - } catch (error: unknown) { - if (error instanceof Error) { - throw new Error(`Error fetching user details: ${error.message}`); - } else { - throw new Error("Error fetching user details"); - } - } -}; +const AppSidebar = React.memo(function AppSidebar({ + ...props +}: React.ComponentProps) { + const { user, firstName, lastName } = useAuth(); -// This is sample data. 
-const data = { - user: { - name: "shadcn", - email: "m@example.com", - avatar: "/avatars/shadcn.jpg", - }, - workspaces: [ - { - name: "Acme Inc", - logo: GalleryVerticalEnd, - plan: "Enterprise", - }, - { - name: "Acme Corp.", - logo: AudioWaveform, - plan: "Startup", - }, - { - name: "Evil Corp.", - logo: Command, - plan: "Free", - }, - ], - navMain: [ + const navMain = [ { title: "Experiments", url: "/experiments", icon: FlaskConicalIcon, }, - { - title: "Integration", - url: "/integration", - icon: ArrowLeftRightIcon, - }, { title: "Dashboard", url: "#", icon: LayoutDashboardIcon, }, + { + title: "Workspaces", + url: "/workspaces", + icon: Building, + }, { title: "Settings", url: "#", icon: Settings2, }, - ], - recentExperiments: [ + ]; + + const recentExperiments = [ { name: "New onboarding flows", url: "#", - icon: Frame, + icon: FlaskConicalIcon, }, { name: "3 different voices", url: "#", - icon: PieChart, + icon: FlaskConicalIcon, }, { name: "AI responses", url: "#", - icon: Map, + icon: FlaskConicalIcon, }, - ], -}; -const AppSidebar = React.memo(function AppSidebar({ - ...props -}: React.ComponentProps) { - const { token } = useAuth(); - const [userDetails, setUserDetails] = React.useState( - null - ); + ]; + + const userDetails = { + firstName: firstName || "?", + lastName: lastName || "?", + username: user || "loading", + }; - React.useEffect(() => { - if (token) { - getUserDetails(token) - .then((data) => setUserDetails(data)) - .catch((error) => console.error(error)); - } - }, [token]); return ( - + - - + + diff --git a/frontend/src/components/create-workspace-dialog.tsx b/frontend/src/components/create-workspace-dialog.tsx new file mode 100644 index 0000000..8a5b75b --- /dev/null +++ b/frontend/src/components/create-workspace-dialog.tsx @@ -0,0 +1,152 @@ +"use client"; + +import { useState } from "react"; +import { apiCalls } from "@/utils/api"; +import { useAuth } from "@/utils/auth"; +import { useToast } from "@/hooks/use-toast"; +import { 
useRouter } from "next/navigation"; +import { z } from "zod"; +import { useForm } from "react-hook-form"; +import { zodResolver } from "@hookform/resolvers/zod"; + +import { + Dialog, + DialogContent, + DialogDescription, + DialogFooter, + DialogHeader, + DialogTitle, +} from "@/components/ui/dialog"; +import { Button } from "@/components/ui/button"; +import { + Form, + FormControl, + FormDescription, + FormField, + FormItem, + FormLabel, + FormMessage, +} from "@/components/ui/form"; +import { Input } from "@/components/ui/input"; + +// Create schema for workspace form +const WorkspaceFormSchema = z.object({ + workspaceName: z + .string() + .min(3, "Workspace name must be at least 3 characters") + .max(50, "Workspace name must be less than 50 characters"), + apiDailyQuota: z + .number() + .int("API quota must be an integer") + .min(1, "API quota must be at least 1") + .optional(), + contentQuota: z + .number() + .int("Content quota must be an integer") + .min(1, "Content quota must be at least 1") + .optional(), +}); + +type WorkspaceFormValues = z.infer; + +interface CreateWorkspaceDialogProps { + open: boolean; + onOpenChange: (open: boolean) => void; +} + +export function CreateWorkspaceDialog({ + open, + onOpenChange, +}: CreateWorkspaceDialogProps) { + const { token, fetchWorkspaces } = useAuth(); + const router = useRouter(); + const { toast } = useToast(); + const [isCreating, setIsCreating] = useState(false); + + // Initialize form + const form = useForm({ + resolver: zodResolver(WorkspaceFormSchema), + defaultValues: { + workspaceName: "", + apiDailyQuota: undefined, + contentQuota: undefined, + }, + }); + + async function onSubmit(data: WorkspaceFormValues) { + if (!token) return; + + setIsCreating(true); + try { + await apiCalls.createWorkspace( + token, + data.workspaceName, + data.apiDailyQuota, + data.contentQuota + ); + + toast({ + title: "Workspace created", + description: `${data.workspaceName} workspace was created successfully.`, + variant: 
"success", + }); + + // Refresh the workspaces list + await fetchWorkspaces(); + + // Close the dialog and reset form + onOpenChange(false); + form.reset(); + + // Navigate to workspaces page + router.push("/workspaces"); + } catch (error: any) { + toast({ + title: "Error creating workspace", + description: error.message || "Failed to create workspace. Please try again.", + variant: "destructive", + }); + } finally { + setIsCreating(false); + } + } + + return ( + + + + Create New Workspace + + Create a new workspace for your team or project + + + +
+ + ( + + Workspace Name + + + + + This is the name that will be displayed for your workspace + + + + )} + /> + + + + + +
+
+ ); +} diff --git a/frontend/src/components/ui/alert-dialog.tsx b/frontend/src/components/ui/alert-dialog.tsx new file mode 100644 index 0000000..1e2629a --- /dev/null +++ b/frontend/src/components/ui/alert-dialog.tsx @@ -0,0 +1,143 @@ +"use client" + +import * as React from "react" +import * as AlertDialogPrimitive from "@radix-ui/react-alert-dialog" + +import { cn } from "@/lib/utils" +import { buttonVariants } from "@/components/ui/button" + +const AlertDialog = AlertDialogPrimitive.Root +const AlertDialogTrigger = AlertDialogPrimitive.Trigger +const AlertDialogPortal = AlertDialogPrimitive.Portal + +const AlertDialogOverlay = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +AlertDialogOverlay.displayName = AlertDialogPrimitive.Overlay.displayName + +const AlertDialogContent = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + + + + +)) +AlertDialogContent.displayName = AlertDialogPrimitive.Content.displayName + +const AlertDialogHeader = ({ + className, + ...props +}: React.HTMLAttributes) => ( +
+) +AlertDialogHeader.displayName = "AlertDialogHeader" + +const AlertDialogFooter = ({ + className, + ...props +}: React.HTMLAttributes) => ( +
+) +AlertDialogFooter.displayName = "AlertDialogFooter" + +const AlertDialogTitle = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +AlertDialogTitle.displayName = AlertDialogPrimitive.Title.displayName + +const AlertDialogDescription = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +AlertDialogDescription.displayName = + AlertDialogPrimitive.Description.displayName + +const AlertDialogAction = React.forwardRef< + HTMLButtonElement, + React.ButtonHTMLAttributes +>(({ className, ...props }, ref) => ( +
@@ -508,6 +535,9 @@ export default function WorkspaceDetailPage() { {new Date( workspace.api_key_updated_datetime_utc ).toLocaleString()} + {workspace.api_key_rotated_by_username && ( + by {workspace.api_key_rotated_by_username} + )}

@@ -541,6 +571,49 @@ export default function WorkspaceDetailPage() { +
+

+ + API Key Rotation History +

+
+
+
Date & Time
+
Rotated By
+
Key Prefix
+
+ {isLoadingHistory ? ( +
+ +

Loading history...

+
+ ) : keyRotationHistory.length > 0 ? ( + keyRotationHistory.map((rotation) => ( +
+
+ {new Date(rotation.rotation_datetime_utc).toLocaleString()} +
+
+ {rotation.rotated_by_username} +
+
+ {rotation.key_first_characters}••••• +
+
+ )) + ) : ( +
+ No rotation history available +
+ )} +
+

+ This table shows the complete history of API key rotations. +

+
+ + +

About API Key Rotation diff --git a/frontend/src/app/(protected)/workspaces/types.ts b/frontend/src/app/(protected)/workspaces/types.ts index 6a0c8de..f604ff5 100644 --- a/frontend/src/app/(protected)/workspaces/types.ts +++ b/frontend/src/app/(protected)/workspaces/types.ts @@ -1,4 +1,18 @@ -export type WorkspaceUser = { +export interface Workspace { + workspace_id: number; + workspace_name: string; + api_daily_quota: number; + content_quota: number; + api_key_first_characters: string; + api_key_updated_datetime_utc: string; + api_key_rotated_by_user_id?: number; + api_key_rotated_by_username?: string; + created_datetime_utc: string; + updated_datetime_utc: string; + is_default: boolean; +} + +export interface WorkspaceUser { user_id: number; username: string; first_name: string; @@ -6,17 +20,13 @@ export type WorkspaceUser = { role: string; is_default_workspace: boolean; created_datetime_utc: string; -}; +} -export type Workspace = { +export interface ApiKeyRotation { + rotation_id: number; workspace_id: number; - workspace_name: string; - api_key_first_characters: string; - api_daily_quota: number; - content_quota: number; - created_datetime_utc: string; - updated_datetime_utc: string; - api_key_updated_datetime_utc: string; - is_default: boolean; -}; - + rotated_by_user_id: number; + rotated_by_username: string; + key_first_characters: string; + rotation_datetime_utc: string; +} diff --git a/frontend/src/utils/api.ts b/frontend/src/utils/api.ts index 597e6e0..f818f30 100644 --- a/frontend/src/utils/api.ts +++ b/frontend/src/utils/api.ts @@ -179,12 +179,16 @@ const switchWorkspace = async (token: string | null, workspaceName: string) => { } }; -const createWorkspace = async (token: string | null, workspaceName: string, - apiDailyQuota?: number, contentQuota?: number) => { +const createWorkspace = async ( + token: string | null, + workspaceName: string, + apiDailyQuota?: number, + contentQuota?: number +) => { try { const response = await api.post( "/workspace/", 
- { + { workspace_name: workspaceName, api_daily_quota: apiDailyQuota, content_quota: contentQuota @@ -317,6 +321,19 @@ const removeUserFromWorkspace = async ( } }; +const getWorkspaceKeyHistory = async (token: string | null, workspaceId: number) => { + try { + const response = await api.get(`/workspace/${workspaceId}/key-history`, { + headers: { + Authorization: `Bearer ${token}`, + }, + }); + return response.data; + } catch (error) { + throw new Error("Error fetching API key history"); + } +}; + export const apiCalls = { getUser, getLoginToken, @@ -336,5 +353,6 @@ export const apiCalls = { getWorkspaceUsers, inviteUserToWorkspace, removeUserFromWorkspace, + getWorkspaceKeyHistory, }; export default api; From daf1f9fc29fad583ba3e887d48e0867131792d12 Mon Sep 17 00:00:00 2001 From: Jay Prakash <0freerunning@gmail.com> Date: Sun, 4 May 2025 23:51:22 +0300 Subject: [PATCH 12/74] Fix the tests --- backend/add_users_to_db.py | 117 ++++++- backend/app/__init__.py | 2 +- backend/app/auth/dependencies.py | 126 ++++++- backend/app/bayes_ab/routers.py | 96 +++--- backend/app/contextual_mab/routers.py | 33 +- backend/app/mab/routers.py | 32 +- backend/app/workspaces/models.py | 33 +- .../versions/9f7482ba882f_workspace_model.py | 123 +++++++ .../aee60a81b8b6_api_keys_history_all.py | 46 --- .../versions/d9f7a309944e_workspace_model.py | 72 ---- .../versions/ee03b9fceb6f_api_keys_history.py | 38 --- backend/tests/conftest.py | 120 ++++++- backend/tests/pytest.ini | 5 + backend/tests/test_auto_fail.py | 55 ++- backend/tests/test_bayes_ab.py | 151 ++++++++- backend/tests/test_cmabs.py | 87 +++-- backend/tests/test_mabs.py | 73 ++-- backend/tests/test_messages.py | 21 ++ backend/tests/test_notifications_job.py | 29 +- backend/tests/test_users.py | 55 ++- backend/tests/test_workspace.py | 319 ++++++++++++++++++ 21 files changed, 1284 insertions(+), 349 deletions(-) create mode 100644 backend/migrations/versions/9f7482ba882f_workspace_model.py delete mode 100644 
backend/migrations/versions/aee60a81b8b6_api_keys_history_all.py delete mode 100644 backend/migrations/versions/d9f7a309944e_workspace_model.py delete mode 100644 backend/migrations/versions/ee03b9fceb6f_api_keys_history.py create mode 100644 backend/tests/pytest.ini create mode 100644 backend/tests/test_workspace.py diff --git a/backend/add_users_to_db.py b/backend/add_users_to_db.py index 01a6e62..734d348 100644 --- a/backend/add_users_to_db.py +++ b/backend/add_users_to_db.py @@ -1,10 +1,12 @@ import asyncio import os from datetime import datetime, timezone +from typing import Optional, Union from redis import asyncio as aioredis from sqlalchemy import select from sqlalchemy.exc import MultipleResultsFound, NoResultFound +from sqlalchemy.orm import Session # Import Session type from app.config import REDIS_HOST from app.database import get_session @@ -15,6 +17,7 @@ get_password_salted_hash, setup_logger, ) +from app.workspaces.models import UserRoles, UserWorkspaceDB, WorkspaceDB logger = setup_logger() @@ -45,7 +48,7 @@ ) -async def async_redis_operations(key: str, value: int | None) -> None: +async def async_redis_operations(key: str, value: Optional[int]) -> None: """ Asynchronous Redis operations to set the remaining API calls for a user. """ @@ -56,7 +59,7 @@ async def async_redis_operations(key: str, value: int | None) -> None: await redis.aclose() -def run_redis_async_tasks(key: str, value: int | str) -> None: +def run_redis_async_tasks(key: str, value: Union[int, str]) -> None: """ Run asynchronous Redis operations to set the remaining API calls for a user. """ @@ -66,23 +69,127 @@ def run_redis_async_tasks(key: str, value: int | str) -> None: loop.run_until_complete(async_redis_operations(key, value_int)) +def ensure_default_workspace(db_session: Session, user_db: UserDB) -> None: + """ + Ensure that a user has a default workspace. + + Parameters + ---------- + db_session + The database session. + user_db + The user DB record. 
+ """ + # Check if user already has a workspace + stmt = select(UserWorkspaceDB).where(UserWorkspaceDB.user_id == user_db.user_id) + result = db_session.execute(stmt) + existing_workspace = result.scalar_one_or_none() + + if existing_workspace: + logger.info( + f"User {user_db.username} already has workspace relationship: " + f"{existing_workspace.workspace_id}" + ) + # Check if any workspace is set as default + stmt = select(UserWorkspaceDB).where( + UserWorkspaceDB.user_id == user_db.user_id, + UserWorkspaceDB.default_workspace, # Fixed boolean comparison + ) + result = db_session.execute(stmt) + default_workspace = result.scalar_one_or_none() + + if default_workspace: + logger.info( + f"User {user_db.username} already has default workspace: " + f"{default_workspace.workspace_id}" + ) + return + else: + # Set first workspace as default + existing_workspace.default_workspace = True + db_session.add(existing_workspace) + db_session.commit() + logger.info( + f"Set workspace {existing_workspace.workspace_id} as default for " + f"{user_db.username}" + ) + return + + # Create a default workspace for the user + workspace_name = f"{user_db.username}'s Workspace" + + # Check if workspace with this name already exists + stmt = select(WorkspaceDB).where(WorkspaceDB.workspace_name == workspace_name) + result = db_session.execute(stmt) + existing_workspace_db = result.scalar_one_or_none() + + if existing_workspace_db: + workspace_db = existing_workspace_db + logger.info( + f"Workspace '{workspace_name}' already exists with ID " + f"{workspace_db.workspace_id}" + ) + else: + # Create new workspace + workspace_db = WorkspaceDB( + workspace_name=workspace_name, + api_daily_quota=100, + content_quota=10, + created_datetime_utc=datetime.now(timezone.utc), + updated_datetime_utc=datetime.now(timezone.utc), + is_default=True, + hashed_api_key=get_key_hash("workspace-api-key-" + workspace_name), + api_key_first_characters="works", + 
api_key_updated_datetime_utc=datetime.now(timezone.utc), + api_key_rotated_by_user_id=user_db.user_id, + ) + db_session.add(workspace_db) + db_session.commit() + logger.info( + f"Created workspace '{workspace_name}' with ID {workspace_db.workspace_id}" + ) + + # Create user-workspace relationship + user_workspace = UserWorkspaceDB( + user_id=user_db.user_id, + workspace_id=workspace_db.workspace_id, + user_role=UserRoles.ADMIN, + default_workspace=True, + created_datetime_utc=datetime.now(timezone.utc), + updated_datetime_utc=datetime.now(timezone.utc), + ) + db_session.add(user_workspace) + db_session.commit() + logger.info( + f"Created workspace relationship for user {user_db.username} with workspace " + f"{workspace_db.workspace_id}" + ) + + if __name__ == "__main__": db_session = next(get_session()) stmt = select(UserDB).where(UserDB.username == user_db.username) result = db_session.execute(stmt) try: - result.one() + existing_user = result.one() logger.info(f"User with username {user_db.username} already exists.") + user_db = existing_user[0] except NoResultFound: db_session.add(user_db) + db_session.flush() + logger.info(f"User with username {user_db.username} added to local database.") run_redis_async_tasks( f"remaining-calls:{user_db.username}", user_db.api_daily_quota ) - logger.info(f"User with username {user_db.username} added to local database.") - except MultipleResultsFound: logger.error( f"Multiple users with username {user_db.username} found in local database." 
) + # Just get the first one + existing_users = result.all() + user_db = existing_users[0][0] + + # Ensure the user has a default workspace + ensure_default_workspace(db_session, user_db) db_session.commit() diff --git a/backend/app/__init__.py b/backend/app/__init__.py index 7e37f2a..550cb07 100644 --- a/backend/app/__init__.py +++ b/backend/app/__init__.py @@ -27,7 +27,7 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]: yield - await app.state.redis.close() + await app.state.redis.aclose() logger.info("Application finished") diff --git a/backend/app/auth/dependencies.py b/backend/app/auth/dependencies.py index a4d18f3..3560783 100644 --- a/backend/app/auth/dependencies.py +++ b/backend/app/auth/dependencies.py @@ -10,10 +10,14 @@ OAuth2PasswordBearer, ) from jwt.exceptions import InvalidTokenError +from redis.asyncio import Redis +from sqlalchemy import case, select +from sqlalchemy.exc import NoResultFound from sqlalchemy.ext.asyncio import AsyncSession from ..config import CHECK_API_LIMIT, DEFAULT_API_QUOTA, DEFAULT_EXPERIMENTS_QUOTA from ..database import get_async_session +from ..users.exceptions import UserNotFoundError from ..users.models import ( UserDB, get_user_by_api_key, @@ -21,14 +25,17 @@ save_user_to_db, update_user_verification_status, ) -from ..users.exceptions import UserNotFoundError from ..users.schemas import UserCreate from ..utils import ( + encode_api_limit, generate_key, + get_key_hash, setup_logger, update_api_limits, verify_password_salted_hash, ) +from ..workspaces.models import UserWorkspaceDB, WorkspaceDB +from ..workspaces.schemas import UserRoles from .config import ( ACCESS_TOKEN_EXPIRE_MINUTES, JWT_ALGORITHM, @@ -66,6 +73,66 @@ async def authenticate_key( raise HTTPException(status_code=403, detail="Invalid API key") from e +async def authenticate_workspace_key( + asession: AsyncSession = Depends(get_async_session), + credentials: HTTPAuthorizationCredentials = Depends(bearer), +) -> UserDB: + """ + Authenticate using 
workspace API key. + Returns the user associated with the workspace for the request context. + """ + token = credentials.credentials + try: + # Check if the token matches any workspace API key + hashed_token = get_key_hash(token) + workspace_stmt = select(WorkspaceDB).where( + WorkspaceDB.hashed_api_key == hashed_token + ) + workspace_result = await asession.execute(workspace_stmt) + workspace = workspace_result.scalar_one_or_none() + + if not workspace: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Invalid workspace API key", + ) + + # Find a user in this workspace to use as context + # Prioritize admin users for better permission context + user_stmt = ( + select(UserDB) + .join(UserWorkspaceDB, UserWorkspaceDB.user_id == UserDB.user_id) + .where(UserWorkspaceDB.workspace_id == workspace.workspace_id) + .where(UserDB.is_active) # Fixed boolean comparison + .order_by(case((UserWorkspaceDB.user_role == UserRoles.ADMIN, 0), else_=1)) + .limit(1) + ) + + user_result = await asession.execute(user_stmt) + user = user_result.scalar_one_or_none() + + if not user: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="No active users associated with this workspace", + ) + + # Add the workspace context to the user object + user.current_workspace = workspace + + return user + except NoResultFound as err: + # Fixed exception chaining + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail="Invalid workspace API key" + ) from err + except Exception as e: + logger.error(f"Error authenticating workspace API key: {str(e)}") + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail="Authorization error" + ) from e + + async def authenticate_credentials( *, username: str, password: str, asession: AsyncSession ) -> Optional[AuthenticatedUser]: @@ -189,7 +256,7 @@ async def get_verified_user( return user_db -def create_access_token(username: str, workspace_name: str = None) -> str: +def 
create_access_token(username: str, workspace_name: Optional[str] = None) -> str: """ Create an access token for the user """ @@ -209,6 +276,61 @@ def create_access_token(username: str, workspace_name: str = None) -> str: return jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM) +async def update_workspace_api_limits( + redis: Redis, workspace_id: int, api_daily_quota: int | None +) -> None: + """ + Update the API limits for workspace in Redis + """ + now = datetime.now(timezone.utc) + next_midnight = (now + timedelta(days=1)).replace( + hour=0, minute=0, second=0, microsecond=0 + ) + key = f"workspace-remaining-calls:{workspace_id}" + expire_at = int(next_midnight.timestamp()) + await redis.set(key, encode_api_limit(api_daily_quota)) + if api_daily_quota is not None: + await redis.expireat(key, expire_at) + + +async def workspace_rate_limiter( + request: Request, + user_db: UserDB = Depends(authenticate_workspace_key), +) -> None: + """ + Rate limiter for the API calls using workspace quota instead of user quota. + """ + if CHECK_API_LIMIT is False: + return + + workspace = user_db.current_workspace + key = f"workspace-remaining-calls:{workspace.workspace_id}" + redis = request.app.state.redis + ttl = await redis.ttl(key) + + # if key does not exist, set the key and value + if ttl == REDIS_KEY_EXPIRED: + await update_workspace_api_limits( + redis, workspace.workspace_id, workspace.api_daily_quota + ) + + nb_remaining = await redis.get(key) + + if nb_remaining != b"None": + nb_remaining = int(nb_remaining) + if nb_remaining <= 0: + raise HTTPException( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + detail=( + "Workspace API call limit reached. Please try again tomorrow " + "or upgrade your plan." 
+ ), + ) + await update_workspace_api_limits( + redis, workspace.workspace_id, nb_remaining - 1 + ) + + async def rate_limiter( request: Request, user_db: UserDB = Depends(authenticate_key), diff --git a/backend/app/bayes_ab/routers.py b/backend/app/bayes_ab/routers.py index 923c582..219c0e0 100644 --- a/backend/app/bayes_ab/routers.py +++ b/backend/app/bayes_ab/routers.py @@ -5,7 +5,7 @@ from fastapi.exceptions import HTTPException from sqlalchemy.ext.asyncio import AsyncSession -from ..auth.dependencies import authenticate_key, get_verified_user +from ..auth.dependencies import authenticate_workspace_key, get_verified_user from ..database import get_async_session from ..models import get_notifications_from_db, save_notifications_to_db from ..schemas import NotificationsResponse, ObservationType @@ -52,20 +52,17 @@ async def create_ab_experiment( user_role = await get_user_role_in_workspace( asession=asession, user_db=user_db, workspace_db=workspace_db ) - + if user_role != UserRoles.ADMIN: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Only workspace administrators can create experiments.", ) - + bayes_ab = await save_bayes_ab_to_db( - experiment, - user_db.user_id, - workspace_db.workspace_id, - asession + experiment, user_db.user_id, workspace_db.workspace_id, asession ) - + notifications = await save_notifications_to_db( experiment_id=bayes_ab.experiment_id, user_id=user_db.user_id, @@ -88,11 +85,9 @@ async def get_bayes_abs( Get details of all experiments in the user's current workspace. """ workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) - + experiments = await get_all_bayes_ab_experiments( - user_db.user_id, - workspace_db.workspace_id, - asession + user_db.user_id, workspace_db.workspace_id, asession ) all_experiments = [] @@ -127,12 +122,9 @@ async def get_bayes_ab( Get details of experiment with the provided `experiment_id`. 
""" workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) - + experiment = await get_bayes_ab_experiment_by_id( - experiment_id, - user_db.user_id, - workspace_db.workspace_id, - asession + experiment_id, user_db.user_id, workspace_db.workspace_id, asession ) if experiment is None: @@ -161,23 +153,22 @@ async def delete_bayes_ab( Delete the experiment with the provided `experiment_id`. """ try: - workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) - + workspace_db = await get_user_default_workspace( + asession=asession, user_db=user_db + ) + user_role = await get_user_role_in_workspace( asession=asession, user_db=user_db, workspace_db=workspace_db ) - + if user_role != UserRoles.ADMIN: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Only workspace administrators can delete experiments.", ) - + experiment = await get_bayes_ab_experiment_by_id( - experiment_id, - user_db.user_id, - workspace_db.workspace_id, - asession + experiment_id, user_db.user_id, workspace_db.workspace_id, asession ) if experiment is None: raise HTTPException( @@ -185,12 +176,9 @@ async def delete_bayes_ab( ) await delete_bayes_ab_experiment_by_id( - experiment_id, - user_db.user_id, - workspace_db.workspace_id, - asession + experiment_id, user_db.user_id, workspace_db.workspace_id, asession ) - + return {"message": f"Experiment with id {experiment_id} deleted successfully."} except Exception as e: raise HTTPException(status_code=500, detail=f"Error: {e}") from e @@ -201,19 +189,17 @@ async def draw_arm( experiment_id: int, draw_id: Optional[str] = None, client_id: Optional[str] = None, - user_db: UserDB = Depends(authenticate_key), + user_db: UserDB = Depends(authenticate_workspace_key), asession: AsyncSession = Depends(get_async_session), ) -> BayesianABDrawResponse: """ Get which arm to pull next for provided experiment. 
""" - workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) - + # Get workspace from user context + workspace_id = user_db.current_workspace.workspace_id + experiment = await get_bayes_ab_experiment_by_id( - experiment_id, - user_db.user_id, - workspace_db.workspace_id, - asession + experiment_id, user_db.user_id, workspace_id, asession ) if experiment is None: @@ -283,21 +269,22 @@ async def save_observation_for_arm( experiment_id: int, draw_id: str, outcome: float, - user_db: UserDB = Depends(authenticate_key), + user_db: UserDB = Depends(authenticate_workspace_key), asession: AsyncSession = Depends(get_async_session), ) -> BayesABArmResponse: """ Update the arm with the provided `arm_id` for the given `experiment_id` based on the `outcome`. """ - workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) - + # Get workspace from user context + workspace_id = user_db.current_workspace.workspace_id + # Get and validate experiment experiment, draw = await validate_experiment_and_draw( experiment_id=experiment_id, draw_id=draw_id, user_id=user_db.user_id, - workspace_id=workspace_db.workspace_id, + workspace_id=workspace_id, asession=asession, ) @@ -316,19 +303,17 @@ async def save_observation_for_arm( ) async def get_outcomes( experiment_id: int, - user_db: UserDB = Depends(authenticate_key), + user_db: UserDB = Depends(authenticate_workspace_key), asession: AsyncSession = Depends(get_async_session), ) -> list[BayesianABObservationResponse]: """ Get the outcomes for the experiment. 
""" - workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) - + # Get workspace from user context + workspace_id = user_db.current_workspace.workspace_id + experiment = await get_bayes_ab_experiment_by_id( - experiment_id, - user_db.user_id, - workspace_db.workspace_id, - asession + experiment_id, user_db.user_id, workspace_id, asession ) if not experiment: raise HTTPException( @@ -350,20 +335,18 @@ async def get_outcomes( ) async def update_arms( experiment_id: int, - user_db: UserDB = Depends(authenticate_key), + user_db: UserDB = Depends(authenticate_workspace_key), asession: AsyncSession = Depends(get_async_session), ) -> list[BayesABArmResponse]: """ Get the outcomes for the experiment. """ - workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) + # Get workspace from user context + workspace_id = user_db.current_workspace.workspace_id # Check experiment params experiment = await get_bayes_ab_experiment_by_id( - experiment_id, - user_db.user_id, - workspace_db.workspace_id, - asession + experiment_id, user_db.user_id, workspace_id, asession ) if not experiment: raise HTTPException( @@ -411,10 +394,7 @@ async def validate_experiment_and_draw( ) -> tuple[BayesianABDB, BayesianABDrawDB]: """Validate the experiment and draw""" experiment = await get_bayes_ab_experiment_by_id( - experiment_id, - user_id, - workspace_id, - asession + experiment_id, user_id, workspace_id, asession ) if experiment is None: raise HTTPException( diff --git a/backend/app/contextual_mab/routers.py b/backend/app/contextual_mab/routers.py index 36c237a..5ce8827 100644 --- a/backend/app/contextual_mab/routers.py +++ b/backend/app/contextual_mab/routers.py @@ -5,7 +5,7 @@ from fastapi.exceptions import HTTPException from sqlalchemy.ext.asyncio import AsyncSession -from ..auth.dependencies import authenticate_key, get_verified_user +from ..auth.dependencies import authenticate_workspace_key, get_verified_user from ..database import 
get_async_session from ..models import get_notifications_from_db, save_notifications_to_db from ..schemas import ( @@ -193,16 +193,17 @@ async def draw_arm( context: List[ContextInput], draw_id: Optional[str] = None, client_id: Optional[str] = None, - user_db: UserDB = Depends(authenticate_key), + user_db: UserDB = Depends(authenticate_workspace_key), asession: AsyncSession = Depends(get_async_session), ) -> CMABDrawResponse: """ Get which arm to pull next for provided experiment. """ - workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) + # Get workspace from user context + workspace_id = user_db.current_workspace.workspace_id experiment = await get_contextual_mab_by_id( - experiment_id, user_db.user_id, workspace_db.workspace_id, asession + experiment_id, user_db.user_id, workspace_id, asession ) if experiment is None: @@ -296,18 +297,19 @@ async def update_arm( experiment_id: int, draw_id: str, reward: float, - user_db: UserDB = Depends(authenticate_key), + user_db: UserDB = Depends(authenticate_workspace_key), asession: AsyncSession = Depends(get_async_session), ) -> ContextualArmResponse: """ Update the arm with the provided `arm_id` for the given `experiment_id` based on the reward. """ - workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) + # Get workspace from user context + workspace_id = user_db.current_workspace.workspace_id # Get the experiment and do checks experiment, draw = await validate_experiment_and_draw( - experiment_id, draw_id, user_db.user_id, asession + experiment_id, draw_id, user_db.user_id, workspace_id, asession ) return await update_based_on_outcome( @@ -321,16 +323,17 @@ async def update_arm( ) async def get_outcomes( experiment_id: int, - user_db: UserDB = Depends(authenticate_key), + user_db: UserDB = Depends(authenticate_workspace_key), asession: AsyncSession = Depends(get_async_session), ) -> list[CMABObservationResponse]: """ Get the outcomes for the experiment. 
""" - workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) + # Get workspace from user context + workspace_id = user_db.current_workspace.workspace_id experiment = await get_contextual_mab_by_id( - experiment_id, user_db.user_id, workspace_db.workspace_id, asession + experiment_id, user_db.user_id, workspace_id, asession ) if not experiment: raise HTTPException( @@ -346,10 +349,16 @@ async def get_outcomes( async def validate_experiment_and_draw( - experiment_id: int, draw_id: str, user_id: int, asession: AsyncSession + experiment_id: int, + draw_id: str, + user_id: int, + workspace_id: int, + asession: AsyncSession, ) -> tuple[ContextualBanditDB, ContextualDrawDB]: """Validate the experiment and draw""" - experiment = await get_contextual_mab_by_id(experiment_id, user_id, asession) + experiment = await get_contextual_mab_by_id( + experiment_id, user_id, workspace_id, asession + ) if experiment is None: raise HTTPException( status_code=404, detail=f"Experiment with id {experiment_id} not found" diff --git a/backend/app/mab/routers.py b/backend/app/mab/routers.py index 360b21c..8354dee 100644 --- a/backend/app/mab/routers.py +++ b/backend/app/mab/routers.py @@ -5,7 +5,7 @@ from fastapi.exceptions import HTTPException from sqlalchemy.ext.asyncio import AsyncSession -from ..auth.dependencies import authenticate_key, get_verified_user +from ..auth.dependencies import authenticate_workspace_key, get_verified_user from ..database import get_async_session from ..models import get_notifications_from_db, save_notifications_to_db from ..schemas import NotificationsResponse, ObservationType @@ -190,16 +190,17 @@ async def draw_arm( experiment_id: int, draw_id: Optional[str] = None, client_id: Optional[str] = None, - user_db: UserDB = Depends(authenticate_key), + user_db: UserDB = Depends(authenticate_workspace_key), asession: AsyncSession = Depends(get_async_session), ) -> MABDrawResponse: """ Draw an arm for the provided experiment. 
""" - workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) + # Get workspace from user context + workspace_id = user_db.current_workspace.workspace_id experiment = await get_mab_by_id( - experiment_id, user_db.user_id, workspace_db.workspace_id, asession + experiment_id, user_db.user_id, workspace_id, asession ) if experiment is None: raise HTTPException( @@ -269,16 +270,18 @@ async def update_arm( experiment_id: int, draw_id: str, outcome: float, - user_db: UserDB = Depends(authenticate_key), + user_db: UserDB = Depends(authenticate_workspace_key), asession: AsyncSession = Depends(get_async_session), ) -> ArmResponse: """ Update the arm with the provided `arm_id` for the given `experiment_id` based on the `outcome`. """ - workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) + # Get workspace from user context + workspace_id = user_db.current_workspace.workspace_id + experiment, draw = await validate_experiment_and_draw( - experiment_id, draw_id, user_db.user_id, asession + experiment_id, draw_id, user_db.user_id, workspace_id, asession ) return await update_based_on_outcome( @@ -292,16 +295,17 @@ async def update_arm( ) async def get_outcomes( experiment_id: int, - user_db: UserDB = Depends(authenticate_key), + user_db: UserDB = Depends(authenticate_workspace_key), asession: AsyncSession = Depends(get_async_session), ) -> list[MABObservationResponse]: """ Get the outcomes for the experiment. 
""" - workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) + # Get workspace from user context + workspace_id = user_db.current_workspace.workspace_id experiment = await get_mab_by_id( - experiment_id, user_db.user_id, workspace_db.workspace_id, asession + experiment_id, user_db.user_id, workspace_id, asession ) if not experiment: raise HTTPException( @@ -318,10 +322,14 @@ async def get_outcomes( async def validate_experiment_and_draw( - experiment_id: int, draw_id: str, user_id: int, asession: AsyncSession + experiment_id: int, + draw_id: str, + user_id: int, + workspace_id: int, + asession: AsyncSession, ) -> tuple[MultiArmedBanditDB, MABDrawDB]: """Validate the experiment and draw""" - experiment = await get_mab_by_id(experiment_id, user_id, asession) + experiment = await get_mab_by_id(experiment_id, user_id, workspace_id, asession) if experiment is None: raise HTTPException( status_code=404, detail=f"Experiment with id {experiment_id} not found" diff --git a/backend/app/workspaces/models.py b/backend/app/workspaces/models.py index 09c5efc..0c3011a 100644 --- a/backend/app/workspaces/models.py +++ b/backend/app/workspaces/models.py @@ -1,5 +1,5 @@ from datetime import datetime, timezone -from typing import TYPE_CHECKING, Sequence +from typing import TYPE_CHECKING, Sequence, cast from sqlalchemy import ( Boolean, @@ -15,12 +15,12 @@ text, update, ) +from sqlalchemy.exc import NoResultFound from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Mapped, mapped_column, relationship -from sqlalchemy.exc import NoResultFound -from ..users.exceptions import UserNotFoundError from ..models import Base, ExperimentBaseDB +from ..users.exceptions import UserNotFoundError from ..users.schemas import UserCreate from .schemas import UserCreateWithCode, UserRoles @@ -98,7 +98,10 @@ class WorkspaceDB(Base): def __repr__(self) -> str: """Define the string representation for the `WorkspaceDB` class.""" - return f"" + return ( + 
f"" + ) class UserWorkspaceDB(Base): @@ -135,7 +138,10 @@ class UserWorkspaceDB(Base): def __repr__(self) -> str: """Define the string representation for the `UserWorkspaceDB` class.""" - return f"." + return ( + f"." + ) class PendingInvitationDB(Base): @@ -165,7 +171,10 @@ class PendingInvitationDB(Base): ) def __repr__(self) -> str: - return f"" + return ( + f"" + ) class ApiKeyRotationHistoryDB(Base): @@ -196,7 +205,10 @@ class ApiKeyRotationHistoryDB(Base): def __repr__(self) -> str: """Define the string representation.""" - return f"" + return ( + f"" + ) async def get_users_in_workspace( @@ -238,7 +250,8 @@ async def remove_user_from_workspace( if not user_workspace: raise UserNotFoundInWorkspaceError( - f"User '{user_db.username}' not found in workspace '{workspace_db.workspace_name}'." + f"User '{user_db.username}' not found in workspace " + f"'{workspace_db.workspace_name}'." ) # Delete the relationship @@ -451,11 +464,13 @@ async def add_existing_user_to_workspace( asession=asession, user_db=user_db, workspace_db=workspace_db ) + user_role = cast(UserRoles, user.role) + _ = await create_user_workspace_role( asession=asession, is_default_workspace=user.is_default_workspace, user_db=user_db, - user_role=user.role, + user_role=user_role, workspace_db=workspace_db, ) diff --git a/backend/migrations/versions/9f7482ba882f_workspace_model.py b/backend/migrations/versions/9f7482ba882f_workspace_model.py new file mode 100644 index 0000000..3543211 --- /dev/null +++ b/backend/migrations/versions/9f7482ba882f_workspace_model.py @@ -0,0 +1,123 @@ +"""Workspace model + +Revision ID: 9f7482ba882f +Revises: 275ff74c0866 +Create Date: 2025-05-04 11:56:03.939578 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. 
+revision: str = "9f7482ba882f" +down_revision: Union[str, None] = "275ff74c0866" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "workspace", + sa.Column("api_daily_quota", sa.Integer(), nullable=True), + sa.Column("api_key_first_characters", sa.String(length=5), nullable=True), + sa.Column( + "api_key_updated_datetime_utc", sa.DateTime(timezone=True), nullable=True + ), + sa.Column("api_key_rotated_by_user_id", sa.Integer(), nullable=True), + sa.Column("content_quota", sa.Integer(), nullable=True), + sa.Column("created_datetime_utc", sa.DateTime(timezone=True), nullable=False), + sa.Column("hashed_api_key", sa.String(length=96), nullable=True), + sa.Column("updated_datetime_utc", sa.DateTime(timezone=True), nullable=False), + sa.Column("workspace_id", sa.Integer(), nullable=False), + sa.Column("workspace_name", sa.String(), nullable=False), + sa.Column("is_default", sa.Boolean(), nullable=False), + sa.ForeignKeyConstraint( + ["api_key_rotated_by_user_id"], + ["users.user_id"], + ), + sa.PrimaryKeyConstraint("workspace_id"), + sa.UniqueConstraint("hashed_api_key"), + sa.UniqueConstraint("workspace_name"), + ) + op.create_table( + "api_key_rotation_history", + sa.Column("rotation_id", sa.Integer(), nullable=False), + sa.Column("workspace_id", sa.Integer(), nullable=False), + sa.Column("rotated_by_user_id", sa.Integer(), nullable=False), + sa.Column("key_first_characters", sa.String(length=5), nullable=False), + sa.Column("rotation_datetime_utc", sa.DateTime(timezone=True), nullable=False), + sa.ForeignKeyConstraint( + ["rotated_by_user_id"], + ["users.user_id"], + ), + sa.ForeignKeyConstraint( + ["workspace_id"], ["workspace.workspace_id"], ondelete="CASCADE" + ), + sa.PrimaryKeyConstraint("rotation_id"), + ) + op.create_table( + "pending_invitations", + sa.Column("invitation_id", sa.Integer(), 
nullable=False), + sa.Column("email", sa.String(), nullable=False), + sa.Column("workspace_id", sa.Integer(), nullable=False), + sa.Column( + "role", + sa.Enum("ADMIN", "READ_ONLY", name="userroles", native_enum=False), + nullable=False, + ), + sa.Column("inviter_id", sa.Integer(), nullable=False), + sa.Column("created_datetime_utc", sa.DateTime(timezone=True), nullable=False), + sa.ForeignKeyConstraint( + ["inviter_id"], + ["users.user_id"], + ), + sa.ForeignKeyConstraint( + ["workspace_id"], ["workspace.workspace_id"], ondelete="CASCADE" + ), + sa.PrimaryKeyConstraint("invitation_id"), + ) + op.create_table( + "user_workspace", + sa.Column("created_datetime_utc", sa.DateTime(timezone=True), nullable=False), + sa.Column( + "default_workspace", + sa.Boolean(), + server_default=sa.text("false"), + nullable=False, + ), + sa.Column("updated_datetime_utc", sa.DateTime(timezone=True), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column( + "user_role", + sa.Enum("ADMIN", "READ_ONLY", name="userroles", native_enum=False), + nullable=False, + ), + sa.Column("workspace_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(["user_id"], ["users.user_id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint( + ["workspace_id"], ["workspace.workspace_id"], ondelete="CASCADE" + ), + sa.PrimaryKeyConstraint("user_id", "workspace_id"), + ) + op.add_column( + "experiments_base", sa.Column("workspace_id", sa.Integer(), nullable=False) + ) + op.create_foreign_key( + None, "experiments_base", "workspace", ["workspace_id"], ["workspace_id"] + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_constraint(None, "experiments_base", type_="foreignkey") + op.drop_column("experiments_base", "workspace_id") + op.drop_table("user_workspace") + op.drop_table("pending_invitations") + op.drop_table("api_key_rotation_history") + op.drop_table("workspace") + # ### end Alembic commands ### diff --git a/backend/migrations/versions/aee60a81b8b6_api_keys_history_all.py b/backend/migrations/versions/aee60a81b8b6_api_keys_history_all.py deleted file mode 100644 index b7c5443..0000000 --- a/backend/migrations/versions/aee60a81b8b6_api_keys_history_all.py +++ /dev/null @@ -1,46 +0,0 @@ -"""API keys history all - -Revision ID: aee60a81b8b6 -Revises: ee03b9fceb6f -Create Date: 2025-05-03 01:44:16.809489 - -""" - -from typing import Sequence, Union - -from alembic import op -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision: str = "aee60a81b8b6" -down_revision: Union[str, None] = "ee03b9fceb6f" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.create_table( - "api_key_rotation_history", - sa.Column("rotation_id", sa.Integer(), nullable=False), - sa.Column("workspace_id", sa.Integer(), nullable=False), - sa.Column("rotated_by_user_id", sa.Integer(), nullable=False), - sa.Column("key_first_characters", sa.String(length=5), nullable=False), - sa.Column("rotation_datetime_utc", sa.DateTime(timezone=True), nullable=False), - sa.ForeignKeyConstraint( - ["rotated_by_user_id"], - ["users.user_id"], - ), - sa.ForeignKeyConstraint( - ["workspace_id"], ["workspace.workspace_id"], ondelete="CASCADE" - ), - sa.PrimaryKeyConstraint("rotation_id"), - ) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! 
### - op.drop_table("api_key_rotation_history") - # ### end Alembic commands ### diff --git a/backend/migrations/versions/d9f7a309944e_workspace_model.py b/backend/migrations/versions/d9f7a309944e_workspace_model.py deleted file mode 100644 index 1482764..0000000 --- a/backend/migrations/versions/d9f7a309944e_workspace_model.py +++ /dev/null @@ -1,72 +0,0 @@ -"""Workspace model - -Revision ID: d9f7a309944e -Revises: 5c15463fda65 -Create Date: 2025-04-30 23:23:21.122138 - -""" -from typing import Sequence, Union - -from alembic import op -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision: str = 'd9f7a309944e' -down_revision: Union[str, None] = '5c15463fda65' -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.create_table('workspace', - sa.Column('api_daily_quota', sa.Integer(), nullable=True), - sa.Column('api_key_first_characters', sa.String(length=5), nullable=True), - sa.Column('api_key_updated_datetime_utc', sa.DateTime(timezone=True), nullable=True), - sa.Column('content_quota', sa.Integer(), nullable=True), - sa.Column('created_datetime_utc', sa.DateTime(timezone=True), nullable=False), - sa.Column('hashed_api_key', sa.String(length=96), nullable=True), - sa.Column('updated_datetime_utc', sa.DateTime(timezone=True), nullable=False), - sa.Column('workspace_id', sa.Integer(), nullable=False), - sa.Column('workspace_name', sa.String(), nullable=False), - sa.Column('is_default', sa.Boolean(), nullable=False), - sa.PrimaryKeyConstraint('workspace_id'), - sa.UniqueConstraint('hashed_api_key'), - sa.UniqueConstraint('workspace_name') - ) - op.create_table('pending_invitations', - sa.Column('invitation_id', sa.Integer(), nullable=False), - sa.Column('email', sa.String(), nullable=False), - sa.Column('workspace_id', sa.Integer(), nullable=False), - sa.Column('role', sa.Enum('ADMIN', 
'READ_ONLY', name='userroles', native_enum=False), nullable=False), - sa.Column('inviter_id', sa.Integer(), nullable=False), - sa.Column('created_datetime_utc', sa.DateTime(timezone=True), nullable=False), - sa.ForeignKeyConstraint(['inviter_id'], ['users.user_id'], ), - sa.ForeignKeyConstraint(['workspace_id'], ['workspace.workspace_id'], ondelete='CASCADE'), - sa.PrimaryKeyConstraint('invitation_id') - ) - op.create_table('user_workspace', - sa.Column('created_datetime_utc', sa.DateTime(timezone=True), nullable=False), - sa.Column('default_workspace', sa.Boolean(), server_default=sa.text('false'), nullable=False), - sa.Column('updated_datetime_utc', sa.DateTime(timezone=True), nullable=False), - sa.Column('user_id', sa.Integer(), nullable=False), - sa.Column('user_role', sa.Enum('ADMIN', 'READ_ONLY', name='userroles', native_enum=False), nullable=False), - sa.Column('workspace_id', sa.Integer(), nullable=False), - sa.ForeignKeyConstraint(['user_id'], ['users.user_id'], ondelete='CASCADE'), - sa.ForeignKeyConstraint(['workspace_id'], ['workspace.workspace_id'], ondelete='CASCADE'), - sa.PrimaryKeyConstraint('user_id', 'workspace_id') - ) - op.add_column('experiments_base', sa.Column('workspace_id', sa.Integer(), nullable=False)) - op.create_foreign_key(None, 'experiments_base', 'workspace', ['workspace_id'], ['workspace_id']) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! 
### - op.drop_constraint(None, 'experiments_base', type_='foreignkey') - op.drop_column('experiments_base', 'workspace_id') - op.drop_table('user_workspace') - op.drop_table('pending_invitations') - op.drop_table('workspace') - # ### end Alembic commands ### diff --git a/backend/migrations/versions/ee03b9fceb6f_api_keys_history.py b/backend/migrations/versions/ee03b9fceb6f_api_keys_history.py deleted file mode 100644 index 42c49ad..0000000 --- a/backend/migrations/versions/ee03b9fceb6f_api_keys_history.py +++ /dev/null @@ -1,38 +0,0 @@ -"""API keys history - -Revision ID: ee03b9fceb6f -Revises: d9f7a309944e -Create Date: 2025-05-03 01:13:44.071704 - -""" - -from typing import Sequence, Union - -from alembic import op -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision: str = "ee03b9fceb6f" -down_revision: Union[str, None] = "d9f7a309944e" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.add_column( - "workspace", - sa.Column("api_key_rotated_by_user_id", sa.Integer(), nullable=True), - ) - op.create_foreign_key( - None, "workspace", "users", ["api_key_rotated_by_user_id"], ["user_id"] - ) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! 
### - op.drop_constraint(None, "workspace", type_="foreignkey") - op.drop_column("workspace", "api_key_rotated_by_user_id") - # ### end Alembic commands ### diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 2c3784c..1c59d27 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -1,10 +1,11 @@ import os -from datetime import datetime +import uuid +from datetime import UTC, datetime from typing import AsyncGenerator, Generator import pytest from fastapi.testclient import TestClient -from sqlalchemy import select +from sqlalchemy import select, text from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine from sqlalchemy.orm import Session @@ -15,6 +16,7 @@ ) from backend.app.users.models import UserDB from backend.app.utils import get_key_hash, get_password_salted_hash +from backend.app.workspaces.models import UserRoles, UserWorkspaceDB, WorkspaceDB from .config import ( TEST_API_QUOTA, @@ -83,42 +85,132 @@ def client() -> Generator[TestClient, None, None]: @pytest.fixture(scope="function") def regular_user(client: TestClient, db_session: Session) -> Generator: + # Create a unique username and API key for each test run to avoid conflicts + unique_id = str(uuid.uuid4())[:8] + unique_username = f"{TEST_USERNAME}_{unique_id}" + unique_api_key = f"{TEST_USER_API_KEY}_{unique_id}" + + # Create user regular_user = UserDB( - username=TEST_USERNAME, + username=unique_username, hashed_password=get_password_salted_hash(TEST_PASSWORD), first_name=TEST_FIRST_NAME, last_name=TEST_LAST_NAME, - hashed_api_key=get_key_hash(TEST_USER_API_KEY), - api_key_first_characters=TEST_USER_API_KEY[:5], - api_key_updated_datetime_utc=datetime.utcnow(), + hashed_api_key=get_key_hash(unique_api_key), + api_key_first_characters=unique_api_key[:5], + api_key_updated_datetime_utc=datetime.now(UTC), experiments_quota=TEST_EXPERIMENTS_QUOTA, api_daily_quota=TEST_API_QUOTA, - created_datetime_utc=datetime.utcnow(), - 
updated_datetime_utc=datetime.utcnow(), + created_datetime_utc=datetime.now(UTC), + updated_datetime_utc=datetime.now(UTC), + is_verified=True, # Make user verified for testing ) db_session.add(regular_user) db_session.commit() - yield regular_user.user_id - db_session.delete(regular_user) + # Create default workspace for user with a unique workspace API key + unique_workspace_api_key = f"workspace_{unique_id}" + default_workspace = WorkspaceDB( + workspace_name=f"{unique_username}'s Workspace", + api_daily_quota=TEST_API_QUOTA, + content_quota=TEST_EXPERIMENTS_QUOTA, + created_datetime_utc=datetime.now(UTC), + updated_datetime_utc=datetime.now(UTC), + is_default=True, + hashed_api_key=get_key_hash(unique_workspace_api_key), + api_key_first_characters=unique_workspace_api_key[:5], + api_key_updated_datetime_utc=datetime.now(UTC), + api_key_rotated_by_user_id=regular_user.user_id, + ) + + db_session.add(default_workspace) db_session.commit() + # Create user workspace relationship + user_workspace = UserWorkspaceDB( + user_id=regular_user.user_id, + workspace_id=default_workspace.workspace_id, + user_role=UserRoles.ADMIN, + default_workspace=True, + created_datetime_utc=datetime.now(UTC), + updated_datetime_utc=datetime.now(UTC), + ) + + db_session.add(user_workspace) + db_session.commit() + + yield regular_user.user_id, unique_username, unique_api_key + + # Clean up - need to handle foreign key relationships properly + try: + # 1. Clean up pending invitations that reference this user as inviter + db_session.execute( + text( + "DELETE FROM pending_invitations WHERE inviter_id = " + f"{regular_user.user_id}" + ) + ) + db_session.commit() + + # 2. Clean up API key rotation history records that reference this user + db_session.execute( + text( + "DELETE FROM api_key_rotation_history WHERE rotated_by_user_id = " + f"{regular_user.user_id}" + ) + ) + db_session.commit() + + # 3. 
Remove the user-workspace relationship + db_session.query(UserWorkspaceDB).filter( + UserWorkspaceDB.user_id == regular_user.user_id + ).delete() + db_session.commit() + + # 4. Remove the reference from workspace.api_key_rotated_by_user_id + db_session.query(WorkspaceDB).filter( + WorkspaceDB.api_key_rotated_by_user_id == regular_user.user_id + ).update({WorkspaceDB.api_key_rotated_by_user_id: None}) + db_session.commit() + + # 5. Now delete the workspace + db_session.query(WorkspaceDB).filter( + WorkspaceDB.workspace_name == f"{unique_username}'s Workspace" + ).delete() + db_session.commit() + + # 6. Finally delete the user + db_session.delete(regular_user) + db_session.commit() + except Exception as e: + # Log the error but don't fail the test + print(f"Error during cleanup: {e}") + db_session.rollback() + @pytest.fixture(scope="session") def user1(client: TestClient, db_session: Session) -> Generator: stmt = select(UserDB).where(UserDB.username == TEST_USERNAME) result = db_session.execute(stmt) - user = result.scalar_one() - yield user.user_id + try: + user = result.scalar_one() + yield user.user_id + except Exception: + # Handle the case where the user doesn't exist + yield None @pytest.fixture(scope="session") def user2(client: TestClient, db_session: Session) -> Generator: stmt = select(UserDB).where(UserDB.username == TEST_USERNAME_2) result = db_session.execute(stmt) - user = result.scalar_one() - yield user.user_id + try: + user = result.scalar_one() + yield user.user_id + except Exception: + # Handle the case where the user doesn't exist + yield None @pytest.fixture(scope="session") diff --git a/backend/tests/pytest.ini b/backend/tests/pytest.ini new file mode 100644 index 0000000..22932f8 --- /dev/null +++ b/backend/tests/pytest.ini @@ -0,0 +1,5 @@ +[pytest] +asyncio_mode = auto + +# Set the default fixture loop scope to function +asyncio_default_fixture_loop_scope = function diff --git a/backend/tests/test_auto_fail.py b/backend/tests/test_auto_fail.py 
index 6016bc3..82cf4e5 100644 --- a/backend/tests/test_auto_fail.py +++ b/backend/tests/test_auto_fail.py @@ -132,6 +132,41 @@ def now(cls, *arg: list) -> datetime: return mydatetime +@fixture +def admin_token(client: TestClient) -> str: + """Get an admin token for authentication""" + response = client.post( + "/login", + data={ + "username": os.environ.get("ADMIN_USERNAME", ""), + "password": os.environ.get("ADMIN_PASSWORD", ""), + }, + ) + token = response.json()["access_token"] + return token + + +@fixture +def workspace_api_key(client: TestClient, admin_token: str) -> str: + """Get the current workspace API key for testing""" + # Get the current workspace + response = client.get( + "/workspace/current", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert response.status_code == 200 + + # Rotate the workspace API key to get a fresh one + response = client.put( + "/workspace/rotate-key", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert response.status_code == 200 + workspace_api_key = response.json()["new_api_key"] + + return workspace_api_key + + class TestMABAutoFailJob: @fixture def create_mab_with_autofail( @@ -155,7 +190,7 @@ def create_mab_with_autofail( mab = response.json() yield mab headers = {"Authorization": f"Bearer {admin_token}"} - client.delete(f"/mab/{['experiment_id']}", headers=headers) + client.delete(f"/mab/{mab['experiment_id']}", headers=headers) @mark.parametrize( "create_mab_with_autofail, fail_value, fail_unit, n_observed", @@ -177,10 +212,10 @@ async def test_auto_fail_job( fail_unit: Literal["days", "hours"], n_observed: int, asession: AsyncSession, + workspace_api_key: str, ) -> None: draws = [] - api_key = os.environ.get("ADMIN_API_KEY", "") - headers = {"Authorization": f"Bearer {api_key}"} + headers = {"Authorization": f"Bearer {workspace_api_key}"} for i in range(1, 15): monkeypatch.setattr( mab_models, @@ -232,7 +267,7 @@ def create_bayes_ab_with_autofail( ab = response.json() yield ab headers = 
{"Authorization": f"Bearer {admin_token}"} - client.delete(f"/bayes_ab/{['experiment_id']}", headers=headers) + client.delete(f"/bayes_ab/{ab['experiment_id']}", headers=headers) @mark.parametrize( "create_bayes_ab_with_autofail, fail_value, fail_unit, n_observed", @@ -254,10 +289,10 @@ async def test_auto_fail_job( fail_unit: Literal["days", "hours"], n_observed: int, asession: AsyncSession, + workspace_api_key: str, ) -> None: draws = [] - api_key = os.environ.get("ADMIN_API_KEY", "") - headers = {"Authorization": f"Bearer {api_key}"} + headers = {"Authorization": f"Bearer {workspace_api_key}"} for i in range(1, 15): monkeypatch.setattr( bayes_ab_models, @@ -309,7 +344,7 @@ def create_cmab_with_autofail( cmab = response.json() yield cmab headers = {"Authorization": f"Bearer {admin_token}"} - client.delete(f"/contextual_mab/{['experiment_id']}", headers=headers) + client.delete(f"/contextual_mab/{cmab['experiment_id']}", headers=headers) @mark.parametrize( "create_cmab_with_autofail, fail_value, fail_unit, n_observed", @@ -331,10 +366,10 @@ async def test_auto_fail_job( fail_unit: Literal["days", "hours"], n_observed: int, asession: AsyncSession, + workspace_api_key: str, ) -> None: draws = [] - api_key = os.environ.get("ADMIN_API_KEY", "") - headers = {"Authorization": f"Bearer {api_key}"} + headers = {"Authorization": f"Bearer {workspace_api_key}"} for i in range(1, 15): monkeypatch.setattr( cmab_models, @@ -347,8 +382,8 @@ async def test_auto_fail_job( response = client.post( f"/contextual_mab/{create_cmab_with_autofail['experiment_id']}/draw", json=[ - {"context_id": 0, "context_value": 0}, {"context_id": 1, "context_value": 0}, + {"context_id": 2, "context_value": 0}, ], headers=headers, ) diff --git a/backend/tests/test_bayes_ab.py b/backend/tests/test_bayes_ab.py index 044d41b..42f8e7a 100644 --- a/backend/tests/test_bayes_ab.py +++ b/backend/tests/test_bayes_ab.py @@ -59,6 +59,42 @@ def clean_bayes_ab(db_session: Session) -> Generator: db_session.commit() 
+@fixture +def admin_token(client: TestClient) -> str: + """Get a token for the admin user""" + response = client.post( + "/login", + data={ + "username": os.environ.get("ADMIN_USERNAME", "admin@idinsight.org"), + "password": os.environ.get("ADMIN_PASSWORD", "12345"), + }, + ) + assert response.status_code == 200, f"Login failed: {response.json()}" + token = response.json()["access_token"] + return token + + +@fixture +def workspace_api_key(client: TestClient, admin_token: str) -> str: + """Get the current workspace API key for testing""" + # Get the current workspace + response = client.get( + "/workspace/current", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert response.status_code == 200 + + # Rotate the workspace API key to get a fresh one + response = client.put( + "/workspace/rotate-key", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert response.status_code == 200 + workspace_api_key = response.json()["new_api_key"] + + return workspace_api_key + + class TestBayesAB: """ Test class for Bayesian A/B testing. 
@@ -195,12 +231,12 @@ def test_draw_arm( create_bayes_abs: list, create_bayes_ab_payload: dict, expected_response: int, + workspace_api_key: str, ) -> None: id = create_bayes_abs[0]["experiment_id"] - api_key = os.environ.get("ADMIN_API_KEY", "") response = client.get( f"/bayes_ab/{id}/draw", - headers={"Authorization": f"Bearer {api_key}"}, + headers={"Authorization": f"Bearer {workspace_api_key}"}, ) assert response.status_code == expected_response @@ -219,12 +255,12 @@ def test_draw_arm_with_client_id( create_bayes_ab_payload: dict, client_id: str | None, expected_response: int, + workspace_api_key: str, ) -> None: id = create_bayes_abs[0]["experiment_id"] - api_key = os.environ.get("ADMIN_API_KEY", "") response = client.get( f"/bayes_ab/{id}/draw{'?client_id=' + client_id if client_id else ''}", - headers={"Authorization": f"Bearer {api_key}"}, + headers={"Authorization": f"Bearer {workspace_api_key}"}, ) assert response.status_code == expected_response @@ -232,19 +268,120 @@ def test_draw_arm_with_client_id( "create_bayes_ab_payload", ["with_sticky_assignment"], indirect=True ) def test_draw_arm_with_sticky_assignment( - self, client: TestClient, create_bayes_abs: list, create_bayes_ab_payload: dict + self, + client: TestClient, + create_bayes_abs: list, + create_bayes_ab_payload: dict, + workspace_api_key: str, ) -> None: id = create_bayes_abs[0]["experiment_id"] - api_key = os.environ.get("ADMIN_API_KEY", "") arm_ids = [] for _ in range(10): response = client.get( f"/bayes_ab/{id}/draw?client_id=123", - headers={"Authorization": f"Bearer {api_key}"}, + headers={"Authorization": f"Bearer {workspace_api_key}"}, ) arm_ids.append(response.json()["arm"]["arm_id"]) assert np.unique(arm_ids).size == 1 + @mark.parametrize("create_bayes_ab_payload", ["base_normal"], indirect=True) + def test_update_observation( + self, + client: TestClient, + create_bayes_abs: list, + create_bayes_ab_payload: dict, + workspace_api_key: str, + ) -> None: + id = 
create_bayes_abs[0]["experiment_id"] + + # First, get a draw + response = client.get( + f"/bayes_ab/{id}/draw", + headers={"Authorization": f"Bearer {workspace_api_key}"}, + ) + assert response.status_code == 200 + draw_id = response.json()["draw_id"] + + # Then update with an observation + response = client.put( + f"/bayes_ab/{id}/{draw_id}/0.5", + headers={"Authorization": f"Bearer {workspace_api_key}"}, + ) + assert response.status_code == 200 + + # Test that we can't update the same draw twice + response = client.put( + f"/bayes_ab/{id}/{draw_id}/0.5", + headers={"Authorization": f"Bearer {workspace_api_key}"}, + ) + assert response.status_code == 400 + + @mark.parametrize("create_bayes_ab_payload", ["base_normal"], indirect=True) + def test_get_outcomes( + self, + client: TestClient, + create_bayes_abs: list, + create_bayes_ab_payload: dict, + workspace_api_key: str, + ) -> None: + id = create_bayes_abs[0]["experiment_id"] + + # First, get a draw + response = client.get( + f"/bayes_ab/{id}/draw", + headers={"Authorization": f"Bearer {workspace_api_key}"}, + ) + assert response.status_code == 200 + draw_id = response.json()["draw_id"] + + # Then update with an observation + response = client.put( + f"/bayes_ab/{id}/{draw_id}/0.5", + headers={"Authorization": f"Bearer {workspace_api_key}"}, + ) + assert response.status_code == 200 + + # Get outcomes + response = client.get( + f"/bayes_ab/{id}/outcomes", + headers={"Authorization": f"Bearer {workspace_api_key}"}, + ) + assert response.status_code == 200 + assert len(response.json()) == 1 + + @mark.parametrize("create_bayes_ab_payload", ["base_normal"], indirect=True) + def test_get_arms( + self, + client: TestClient, + create_bayes_abs: list, + create_bayes_ab_payload: dict, + workspace_api_key: str, + ) -> None: + id = create_bayes_abs[0]["experiment_id"] + + # First, get a draw + response = client.get( + f"/bayes_ab/{id}/draw", + headers={"Authorization": f"Bearer {workspace_api_key}"}, + ) + assert 
response.status_code == 200 + draw_id = response.json()["draw_id"] + + # Then update with an observation + response = client.put( + f"/bayes_ab/{id}/{draw_id}/0.5", + headers={"Authorization": f"Bearer {workspace_api_key}"}, + ) + assert response.status_code == 200 + + # Get arms + response = client.get( + f"/bayes_ab/{id}/arms", + headers={"Authorization": f"Bearer {workspace_api_key}"}, + ) + assert response.status_code == 200 + assert len(response.json()) == 2 + class TestNotifications: @fixture() diff --git a/backend/tests/test_cmabs.py b/backend/tests/test_cmabs.py index 1a19cd0..406c14a 100644 --- a/backend/tests/test_cmabs.py +++ b/backend/tests/test_cmabs.py @@ -60,6 +60,40 @@ base_binary_normal_payload["reward_type"] = "binary" +@fixture +def admin_token(client: TestClient) -> str: + response = client.post( + "/login", + data={ + "username": os.environ.get("ADMIN_USERNAME", ""), + "password": os.environ.get("ADMIN_PASSWORD", ""), + }, + ) + token = response.json()["access_token"] + return token + + +@fixture +def workspace_api_key(client: TestClient, admin_token: str) -> str: + """Get the current workspace API key for testing""" + # Get the current workspace + response = client.get( + "/workspace/current", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert response.status_code == 200 + + # Rotate the workspace API key to get a fresh one + response = client.put( + "/workspace/rotate-key", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert response.status_code == 200 + workspace_api_key = response.json()["new_api_key"] + + return workspace_api_key + + @fixture def clean_cmabs(db_session: Session) -> Generator: yield @@ -201,13 +235,16 @@ def test_get_cmab( @mark.parametrize("create_cmab_payload", ["base_normal"], indirect=True) def test_draw_arm_draw_id_provided( - self, client: TestClient, create_cmabs: list, create_cmab_payload: dict + self, + client: TestClient, + create_cmabs: list, + create_cmab_payload: dict, + 
workspace_api_key: str, ) -> None: id = create_cmabs[0]["experiment_id"] - api_key = os.environ.get("ADMIN_API_KEY", "") response = client.post( f"/contextual_mab/{id}/draw", - headers={"Authorization": f"Bearer {api_key}"}, + headers={"Authorization": f"Bearer {workspace_api_key}"}, params={"draw_id": "test_draw_id"}, json=[ {"context_id": 1, "context_value": 0}, @@ -219,13 +256,16 @@ def test_draw_arm_draw_id_provided( @mark.parametrize("create_cmab_payload", ["base_normal"], indirect=True) def test_draw_arm_no_draw_id_provided( - self, client: TestClient, create_cmabs: list, create_cmab_payload: dict + self, + client: TestClient, + create_cmabs: list, + create_cmab_payload: dict, + workspace_api_key: str, ) -> None: id = create_cmabs[0]["experiment_id"] - api_key = os.environ.get("ADMIN_API_KEY", "") response = client.post( f"/contextual_mab/{id}/draw", - headers={"Authorization": f"Bearer {api_key}"}, + headers={"Authorization": f"Bearer {workspace_api_key}"}, json=[ {"context_id": 1, "context_value": 0}, {"context_id": 2, "context_value": 0.5}, @@ -249,16 +289,16 @@ def test_draw_arm_sticky_assignment_client_id_provided( create_cmab_payload: dict, client_id: str | None, expected_response: int, + workspace_api_key: str, ) -> None: id = create_cmabs[0]["experiment_id"] - api_key = os.environ.get("ADMIN_API_KEY", "") url = f"/contextual_mab/{id}/draw" if client_id: url += f"?client_id={client_id}" response = client.post( url, - headers={"Authorization": f"Bearer {api_key}"}, + headers={"Authorization": f"Bearer {workspace_api_key}"}, json=[ {"context_id": 1, "context_value": 0}, {"context_id": 2, "context_value": 0.5}, @@ -268,16 +308,19 @@ def test_draw_arm_sticky_assignment_client_id_provided( @mark.parametrize("create_cmab_payload", ["with_sticky_assignment"], indirect=True) def test_draw_arm_with_sticky_assignment( - self, client: TestClient, create_cmabs: list, create_cmab_payload: dict + self, + client: TestClient, + create_cmabs: list, + 
create_cmab_payload: dict, + workspace_api_key: str, ) -> None: id = create_cmabs[0]["experiment_id"] - api_key = os.environ.get("ADMIN_API_KEY", "") arm_ids = [] for _ in range(10): response = client.post( f"/contextual_mab/{id}/draw?client_id=123", - headers={"Authorization": f"Bearer {api_key}"}, + headers={"Authorization": f"Bearer {workspace_api_key}"}, json=[ {"context_id": 1, "context_value": 0}, {"context_id": 2, "context_value": 1}, @@ -289,13 +332,16 @@ def test_draw_arm_with_sticky_assignment( @mark.parametrize("create_cmab_payload", ["base_normal"], indirect=True) def test_one_outcome_per_draw( - self, client: TestClient, create_cmabs: list, create_cmab_payload: dict + self, + client: TestClient, + create_cmabs: list, + create_cmab_payload: dict, + workspace_api_key: str, ) -> None: id = create_cmabs[0]["experiment_id"] - api_key = os.environ.get("ADMIN_API_KEY", "") response = client.post( f"/contextual_mab/{id}/draw", - headers={"Authorization": f"Bearer {api_key}"}, + headers={"Authorization": f"Bearer {workspace_api_key}"}, json=[ {"context_id": 1, "context_value": 0}, {"context_id": 2, "context_value": 0.5}, @@ -306,14 +352,14 @@ def test_one_outcome_per_draw( response = client.put( f"/contextual_mab/{id}/{draw_id}/1", - headers={"Authorization": f"Bearer {api_key}"}, + headers={"Authorization": f"Bearer {workspace_api_key}"}, ) assert response.status_code == 200 response = client.put( f"/contextual_mab/{id}/{draw_id}/1", - headers={"Authorization": f"Bearer {api_key}"}, + headers={"Authorization": f"Bearer {workspace_api_key}"}, ) assert response.status_code == 400 @@ -329,15 +375,14 @@ def test_get_outcomes( create_cmabs: list, n_draws: int, create_cmab_payload: dict, + workspace_api_key: str, ) -> None: id = create_cmabs[0]["experiment_id"] - api_key = os.environ.get("ADMIN_API_KEY", "") - id = create_cmabs[0]["experiment_id"] for _ in range(n_draws): response = client.post( f"/contextual_mab/{id}/draw", - headers={"Authorization": f"Bearer 
{api_key}"}, + headers={"Authorization": f"Bearer {workspace_api_key}"}, json=[ {"context_id": 1, "context_value": 0}, {"context_id": 2, "context_value": 0.5}, @@ -347,12 +392,12 @@ def test_get_outcomes( draw_id = response.json()["draw_id"] response = client.put( f"/contextual_mab/{id}/{draw_id}/1", - headers={"Authorization": f"Bearer {api_key}"}, + headers={"Authorization": f"Bearer {workspace_api_key}"}, ) response = client.get( f"/contextual_mab/{id}/outcomes", - headers={"Authorization": f"Bearer {api_key}"}, + headers={"Authorization": f"Bearer {workspace_api_key}"}, ) assert response.status_code == 200 diff --git a/backend/tests/test_mabs.py b/backend/tests/test_mabs.py index b1d0fa8..e9612ca 100644 --- a/backend/tests/test_mabs.py +++ b/backend/tests/test_mabs.py @@ -72,6 +72,27 @@ def admin_token(client: TestClient) -> str: return token +@fixture +def workspace_api_key(client: TestClient, admin_token: str, db_session: Session) -> str: + """Get the current workspace API key for testing""" + # Get the current workspace + response = client.get( + "/workspace/current", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert response.status_code == 200 + + # Rotate the workspace API key to get a fresh one + response = client.put( + "/workspace/rotate-key", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert response.status_code == 200 + workspace_api_key = response.json()["new_api_key"] + + return workspace_api_key + + @fixture def clean_mabs(db_session: Session) -> Generator: yield @@ -233,27 +254,33 @@ def test_get_mab( @mark.parametrize("create_mab_payload", ["base_beta_binom"], indirect=True) def test_draw_arm_draw_id_provided( - self, client: TestClient, create_mabs: list, create_mab_payload: dict + self, + client: TestClient, + create_mabs: list, + create_mab_payload: dict, + workspace_api_key: str, ) -> None: id = create_mabs[0]["experiment_id"] - api_key = os.environ.get("ADMIN_API_KEY", "") response = client.get( 
f"/mab/{id}/draw", params={"draw_id": "test_draw"}, - headers={"Authorization": f"Bearer {api_key}"}, + headers={"Authorization": f"Bearer {workspace_api_key}"}, ) assert response.status_code == 200 assert response.json()["draw_id"] == "test_draw" @mark.parametrize("create_mab_payload", ["base_beta_binom"], indirect=True) def test_draw_arm_no_draw_id_provided( - self, client: TestClient, create_mabs: list, create_mab_payload: dict + self, + client: TestClient, + create_mabs: list, + create_mab_payload: dict, + workspace_api_key: str, ) -> None: id = create_mabs[0]["experiment_id"] - api_key = os.environ.get("ADMIN_API_KEY", "") response = client.get( f"/mab/{id}/draw", - headers={"Authorization": f"Bearer {api_key}"}, + headers={"Authorization": f"Bearer {workspace_api_key}"}, ) assert response.status_code == 200 assert len(response.json()["draw_id"]) == 36 @@ -274,13 +301,13 @@ def test_draw_arm_sticky_assignment_with_client_id( create_mabs: list, client_id: str | None, expected_response: int, + workspace_api_key: str, ) -> None: mabs = create_mabs id = mabs[0]["experiment_id"] - api_key = os.environ.get("ADMIN_API_KEY", "") response = client.get( f"/mab/{id}/draw{'?client_id=' + client_id if client_id else ''}", - headers={"Authorization": f"Bearer {api_key}"}, + headers={"Authorization": f"Bearer {workspace_api_key}"}, ) assert response.status_code == expected_response @@ -291,13 +318,13 @@ def test_draw_arm_sticky_assignment_client_id_provided( admin_token: str, create_mab_payload: dict, create_mabs: list, + workspace_api_key: str, ) -> None: mabs = create_mabs id = mabs[0]["experiment_id"] - api_key = os.environ.get("ADMIN_API_KEY", "") response = client.get( f"/mab/{id}/draw?client_id=123", - headers={"Authorization": f"Bearer {api_key}"}, + headers={"Authorization": f"Bearer {workspace_api_key}"}, ) assert response.status_code == 200 @@ -308,43 +335,46 @@ def test_draw_arm_sticky_assignment_similar_arms( admin_token: str, create_mab_payload: dict, 
create_mabs: list, + workspace_api_key: str, ) -> None: mabs = create_mabs id = mabs[0]["experiment_id"] - api_key = os.environ.get("ADMIN_API_KEY", "") arm_ids = [] for _ in range(10): response = client.get( f"/mab/{id}/draw?client_id=123", - headers={"Authorization": f"Bearer {api_key}"}, + headers={"Authorization": f"Bearer {workspace_api_key}"}, ) arm_ids.append(response.json()["arm"]["arm_id"]) assert np.unique(arm_ids).size == 1 @mark.parametrize("create_mab_payload", ["base_beta_binom"], indirect=True) def test_one_outcome_per_draw( - self, client: TestClient, create_mabs: list, create_mab_payload: dict + self, + client: TestClient, + create_mabs: list, + create_mab_payload: dict, + workspace_api_key: str, ) -> None: id = create_mabs[0]["experiment_id"] - api_key = os.environ.get("ADMIN_API_KEY", "") response = client.get( f"/mab/{id}/draw", - headers={"Authorization": f"Bearer {api_key}"}, + headers={"Authorization": f"Bearer {workspace_api_key}"}, ) assert response.status_code == 200 draw_id = response.json()["draw_id"] response = client.put( f"/mab/{id}/{draw_id}/1", - headers={"Authorization": f"Bearer {api_key}"}, + headers={"Authorization": f"Bearer {workspace_api_key}"}, ) assert response.status_code == 200 response = client.put( f"/mab/{id}/{draw_id}/1", - headers={"Authorization": f"Bearer {api_key}"}, + headers={"Authorization": f"Bearer {workspace_api_key}"}, ) assert response.status_code == 400 @@ -360,27 +390,26 @@ def test_get_outcomes( create_mabs: list, n_draws: int, create_mab_payload: dict, + workspace_api_key: str, ) -> None: id = create_mabs[0]["experiment_id"] - api_key = os.environ.get("ADMIN_API_KEY", "") - id = create_mabs[0]["experiment_id"] for _ in range(n_draws): response = client.get( f"/mab/{id}/draw", - headers={"Authorization": f"Bearer {api_key}"}, + headers={"Authorization": f"Bearer {workspace_api_key}"}, ) assert response.status_code == 200 draw_id = response.json()["draw_id"] # put outcomes response = client.put( 
f"/mab/{id}/{draw_id}/1", - headers={"Authorization": f"Bearer {api_key}"}, + headers={"Authorization": f"Bearer {workspace_api_key}"}, ) response = client.get( f"/mab/{id}/outcomes", - headers={"Authorization": f"Bearer {api_key}"}, + headers={"Authorization": f"Bearer {workspace_api_key}"}, ) assert response.status_code == 200 diff --git a/backend/tests/test_messages.py b/backend/tests/test_messages.py index 86396b4..782ba2c 100644 --- a/backend/tests/test_messages.py +++ b/backend/tests/test_messages.py @@ -50,6 +50,27 @@ def admin_token(client: TestClient) -> str: return token +@fixture +def workspace_api_key(client: TestClient, admin_token: str) -> str: + """Get the current workspace API key for testing""" + # Get the current workspace + response = client.get( + "/workspace/current", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert response.status_code == 200 + + # Rotate the workspace API key to get a fresh one + response = client.put( + "/workspace/rotate-key", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert response.status_code == 200 + workspace_api_key = response.json()["new_api_key"] + + return workspace_api_key + + @fixture def experiment_id(client: TestClient, admin_token: str) -> Generator[int, None, None]: response = client.post( diff --git a/backend/tests/test_notifications_job.py b/backend/tests/test_notifications_job.py index 91d077c..6856920 100644 --- a/backend/tests/test_notifications_job.py +++ b/backend/tests/test_notifications_job.py @@ -53,6 +53,7 @@ def now(cls, *arg: list) -> datetime: @fixture def admin_token(client: TestClient) -> str: + """Get an admin token for authentication""" response = client.post( "/login", data={ @@ -64,6 +65,27 @@ def admin_token(client: TestClient) -> str: return token +@fixture +def workspace_api_key(client: TestClient, admin_token: str) -> str: + """Get the current workspace API key for testing""" + # Get the current workspace + response = client.get( + 
"/workspace/current", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert response.status_code == 200 + + # Rotate the workspace API key to get a fresh one + response = client.put( + "/workspace/rotate-key", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert response.status_code == 200 + workspace_api_key = response.json()["new_api_key"] + + return workspace_api_key + + class TestNotificationsJob: @fixture def create_mabs_days_elapsed( @@ -174,24 +196,25 @@ async def test_trials_run_notification( create_mabs_trials_run: list[dict], db_session: Session, asession: AsyncSession, + workspace_api_key: str, ) -> None: n_processed = await process_notifications(asession) assert n_processed == 0 - api_key = os.environ.get("ADMIN_API_KEY", "") + headers = {"Authorization": f"Bearer {workspace_api_key}"} for mab in create_mabs_trials_run: for i in range(n_trials): draw_id = f"draw_{i}_{mab['experiment_id']}" response = client.get( f"/mab/{mab['experiment_id']}/draw", params={"draw_id": draw_id}, - headers={"Authorization": f"Bearer {api_key}"}, + headers=headers, ) assert response.status_code == 200 assert response.json()["draw_id"] == draw_id response = client.put( f"/mab/{mab['experiment_id']}/{draw_id}/1", - headers={"Authorization": f"Bearer {api_key}"}, + headers=headers, ) assert response.status_code == 200 n_processed = await process_notifications(asession) diff --git a/backend/tests/test_users.py b/backend/tests/test_users.py index c1d3f92..f966273 100644 --- a/backend/tests/test_users.py +++ b/backend/tests/test_users.py @@ -9,7 +9,7 @@ from backend.app.auth.dependencies import get_current_user, get_verified_user from backend.app.users.models import UserDB -from .config import TEST_PASSWORD, TEST_USER_API_KEY, TEST_USERNAME +from .config import TEST_PASSWORD @fixture @@ -33,10 +33,11 @@ def admin_token(self, client: TestClient) -> str: return token @fixture - def user_token(self, client: TestClient, regular_user: int) -> str: + def 
user_token(self, client: TestClient, regular_user: tuple) -> str: + user_id, username, _ = regular_user response = client.post( "/login", - data={"username": TEST_USERNAME, "password": TEST_PASSWORD}, + data={"username": username, "password": TEST_PASSWORD}, ) token = response.json()["access_token"] return token @@ -72,10 +73,11 @@ def test_user_id_2_cannot_create_user( self, client: TestClient, mock_send_email: MagicMock ) -> None: # Register a user + username = f"user_test1_{os.urandom(4).hex()}" response = client.post( "/user/", json={ - "username": "user_test1", + "username": username, "password": "password_test", "first_name": "Test", "last_name": "User", @@ -83,11 +85,11 @@ def test_user_id_2_cannot_create_user( ) assert response.status_code == 200 - # Try to register another user + # Try to register another user with the same username response = client.post( "/user/", json={ - "username": "user_test1", + "username": username, "password": "password_test", "first_name": "Test", "last_name": "User", @@ -96,28 +98,47 @@ def test_user_id_2_cannot_create_user( assert response.status_code == 400 def test_get_current_user( - self, client: TestClient, user_token: str, regular_user: int + self, client: TestClient, user_token: str, regular_user: tuple ) -> None: + user_id, username, _ = regular_user response = client.get( "/user/", headers={"Authorization": f"Bearer {user_token}"}, ) assert response.status_code == 200 - assert response.json()["user_id"] == regular_user - assert response.json()["username"] == TEST_USERNAME + assert response.json()["user_id"] == user_id + assert response.json()["username"] == username - def test_rotate_key( - self, client: TestClient, user_token: str, mock_verified_user: None + def test_login_creates_default_workspace( + self, client: TestClient, mock_send_email: MagicMock ) -> None: - response = client.get( - "/user/", headers={"Authorization": f"Bearer {user_token}"} + # Register a new user + test_username = 
f"workspace_user_{os.urandom(4).hex()}@test.com" + response = client.post( + "/user/", + json={ + "username": test_username, + "password": "password_test", + "first_name": "Workspace", + "last_name": "User", + }, ) assert response.status_code == 200 - assert response.json()["api_key_first_characters"] == TEST_USER_API_KEY[:5] - response = client.put( - "/user/rotate-key", headers={"Authorization": f"Bearer {user_token}"} + # Login with the new user + response = client.post( + "/login", + data={"username": test_username, "password": "password_test"}, + ) + assert response.status_code == 200 + token = response.json()["access_token"] + + # Check if a default workspace was created for the user + response = client.get( + "/workspace/current", + headers={"Authorization": f"Bearer {token}"}, ) assert response.status_code == 200 - assert response.json()["new_api_key"] != TEST_USER_API_KEY + assert response.json()["workspace_name"] == f"{test_username}'s Workspace" + assert response.json()["is_default"] is True diff --git a/backend/tests/test_workspace.py b/backend/tests/test_workspace.py new file mode 100644 index 0000000..f865507 --- /dev/null +++ b/backend/tests/test_workspace.py @@ -0,0 +1,319 @@ +import os +from typing import Annotated, Generator, cast +from unittest.mock import MagicMock, patch + +from fastapi import Depends, FastAPI +from fastapi.testclient import TestClient +from pytest import fixture +from sqlalchemy.ext.asyncio import AsyncSession + +from backend.app.auth.dependencies import get_current_user, get_verified_user +from backend.app.database import get_async_session +from backend.app.users.models import UserDB +from backend.app.workspaces.models import UserRoles + +from .config import TEST_PASSWORD + + +@fixture +def mock_send_email() -> Generator[MagicMock, None, None]: + with patch("backend.app.email.EmailService._send_email") as mocked_send: + mocked_send.return_value = {"MessageId": "mock-message-id"} + yield mocked_send + + +class TestWorkspace: + 
@fixture + def user_token(self, client: TestClient, regular_user: tuple) -> str: + user_id, username, _ = regular_user + response = client.post( + "/login", + data={"username": username, "password": TEST_PASSWORD}, + ) + token = response.json()["access_token"] + return token + + @fixture + def mock_verified_user(self, client: TestClient) -> Generator[None, None, None]: + async def mock_get_verified_user( + user_db: Annotated[UserDB, Depends(get_current_user)], + ) -> UserDB: + return user_db + + app = cast(FastAPI, client.app) + app.dependency_overrides[get_verified_user] = mock_get_verified_user + yield + app.dependency_overrides.clear() + + @fixture + def mock_async_session(self, client: TestClient) -> Generator[None, None, None]: + async def override_get_async_session() -> AsyncSession: + # This would need to be implemented to return a test async session + pass + + app = cast(FastAPI, client.app) + app.dependency_overrides[get_async_session] = override_get_async_session + yield + app.dependency_overrides.clear() + + def test_get_current_workspace( + self, client: TestClient, user_token: str, mock_verified_user: None + ) -> None: + response = client.get( + "/workspace/current", + headers={"Authorization": f"Bearer {user_token}"}, + ) + assert response.status_code == 200 + assert "workspace_name" in response.json() + assert "is_default" in response.json() + assert response.json()["is_default"] is True + + def test_retrieve_all_workspaces( + self, client: TestClient, user_token: str, mock_verified_user: None + ) -> None: + response = client.get( + "/workspace/", + headers={"Authorization": f"Bearer {user_token}"}, + ) + assert response.status_code == 200 + assert isinstance(response.json(), list) + # At least one workspace (default) should exist + assert len(response.json()) >= 1 + + def test_create_workspace( + self, client: TestClient, user_token: str, mock_verified_user: None + ) -> None: + workspace_name = "Test Workspace Creation" + response = client.post( + 
"/workspace/", + headers={"Authorization": f"Bearer {user_token}"}, + json={ + "workspace_name": workspace_name, + "api_daily_quota": 1000, + "content_quota": 50, + }, + ) + assert response.status_code == 200 + assert response.json()["workspace_name"] == workspace_name + assert response.json()["api_daily_quota"] == 1000 + assert response.json()["content_quota"] == 50 + + # Verify the workspace exists in the list of workspaces + response = client.get( + "/workspace/", + headers={"Authorization": f"Bearer {user_token}"}, + ) + assert response.status_code == 200 + workspaces = response.json() + assert any(ws["workspace_name"] == workspace_name for ws in workspaces) + + def test_update_workspace( + self, client: TestClient, user_token: str, mock_verified_user: None + ) -> None: + # First create a workspace + original_name = "Original Workspace Name" + response = client.post( + "/workspace/", + headers={"Authorization": f"Bearer {user_token}"}, + json={"workspace_name": original_name}, + ) + assert response.status_code == 200 + workspace_id = response.json()["workspace_id"] + + # Now update the workspace name + new_name = "Updated Workspace Name" + response = client.put( + f"/workspace/{workspace_id}", + headers={"Authorization": f"Bearer {user_token}"}, + json={"workspace_name": new_name}, + ) + assert response.status_code == 200 + assert response.json()["workspace_name"] == new_name + + # Verify the workspace has been updated in the list + response = client.get( + "/workspace/", + headers={"Authorization": f"Bearer {user_token}"}, + ) + assert response.status_code == 200 + workspaces = response.json() + assert any(ws["workspace_name"] == new_name for ws in workspaces) + assert not any(ws["workspace_name"] == original_name for ws in workspaces) + + def test_switch_workspace( + self, client: TestClient, user_token: str, mock_verified_user: None + ) -> None: + # First create a new workspace + workspace_name = "Test Switch Workspace" + response = client.post( + 
"/workspace/", + headers={"Authorization": f"Bearer {user_token}"}, + json={"workspace_name": workspace_name}, + ) + assert response.status_code == 200 + + # Now switch to the new workspace + response = client.post( + "/workspace/switch", + headers={"Authorization": f"Bearer {user_token}"}, + json={"workspace_name": workspace_name}, + ) + assert response.status_code == 200 + assert "access_token" in response.json() + + # Verify that the current workspace is now the one we switched to + new_token = response.json()["access_token"] + response = client.get( + "/workspace/current", + headers={"Authorization": f"Bearer {new_token}"}, + ) + assert response.status_code == 200 + assert response.json()["workspace_name"] == workspace_name + + def test_rotate_workspace_api_key( + self, client: TestClient, user_token: str, mock_verified_user: None + ) -> None: + response = client.put( + "/workspace/rotate-key", + headers={"Authorization": f"Bearer {user_token}"}, + ) + assert response.status_code == 200 + assert "new_api_key" in response.json() + assert "workspace_name" in response.json() + + def test_invite_user_to_workspace( + self, + client: TestClient, + user_token: str, + mock_verified_user: None, + mock_send_email: MagicMock, + ) -> None: + # First create a non-default workspace + workspace_name = "Test Invite Workspace" + response = client.post( + "/workspace/", + headers={"Authorization": f"Bearer {user_token}"}, + json={"workspace_name": workspace_name}, + ) + assert response.status_code == 200 + + # Now invite a user to this workspace + invite_email = ( + f"invited_user_{os.urandom(4).hex()}@example.com" # Use unique email + ) + response = client.post( + "/workspace/invite", + headers={"Authorization": f"Bearer {user_token}"}, + json={ + "workspace_name": workspace_name, + "email": invite_email, + "role": UserRoles.READ_ONLY, + }, + ) + assert response.status_code == 200 + assert response.json()["email"] == invite_email + assert response.json()["workspace_name"] == 
workspace_name + assert "message" in response.json() + + # Verify that email was called + mock_send_email.assert_called_once() + + def test_get_workspace_users( + self, client: TestClient, user_token: str, mock_verified_user: None + ) -> None: + # First create a workspace + workspace_name = "Test Workspace Users" + response = client.post( + "/workspace/", + headers={"Authorization": f"Bearer {user_token}"}, + json={"workspace_name": workspace_name}, + ) + assert response.status_code == 200 + workspace_id = response.json()["workspace_id"] + + # Get users in workspace + response = client.get( + f"/workspace/{workspace_id}/users", + headers={"Authorization": f"Bearer {user_token}"}, + ) + assert response.status_code == 200 + assert isinstance(response.json(), list) + # Creator should be in the workspace with ADMIN role + users = response.json() + assert len(users) >= 1 + # Since username is now dynamic, just check that at least one user + # has ADMIN role + assert any(user["role"] == UserRoles.ADMIN for user in users) + + def test_get_workspace_by_id( + self, client: TestClient, user_token: str, mock_verified_user: None + ) -> None: + # First create a workspace + workspace_name = "Test Get Workspace" + response = client.post( + "/workspace/", + headers={"Authorization": f"Bearer {user_token}"}, + json={"workspace_name": workspace_name}, + ) + assert response.status_code == 200 + workspace_id = response.json()["workspace_id"] + + # Get workspace by ID + response = client.get( + f"/workspace/{workspace_id}", + headers={"Authorization": f"Bearer {user_token}"}, + ) + assert response.status_code == 200 + assert response.json()["workspace_id"] == workspace_id + assert response.json()["workspace_name"] == workspace_name + + def test_get_workspace_key_history( + self, client: TestClient, user_token: str, mock_verified_user: None + ) -> None: + # First create a workspace + workspace_name = "Test Key History" + response = client.post( + "/workspace/", + 
headers={"Authorization": f"Bearer {user_token}"}, + json={"workspace_name": workspace_name}, + ) + assert response.status_code == 200 + workspace_id = response.json()["workspace_id"] + + # Rotate API key to create a history record + response = client.put( + "/workspace/rotate-key", + headers={"Authorization": f"Bearer {user_token}"}, + ) + assert response.status_code == 200 + + # Wait a moment to ensure DB transactions complete + import time + + time.sleep(0.5) + + # Switch to the newly created workspace + response = client.post( + "/workspace/switch", + headers={"Authorization": f"Bearer {user_token}"}, + json={"workspace_name": workspace_name}, + ) + assert response.status_code == 200 + new_token = response.json()["access_token"] + + # Now rotate the key for this specific workspace + response = client.put( + "/workspace/rotate-key", + headers={"Authorization": f"Bearer {new_token}"}, + ) + assert response.status_code == 200 + + # Get key rotation history + response = client.get( + f"/workspace/{workspace_id}/key-history", + headers={"Authorization": f"Bearer {new_token}"}, + ) + assert response.status_code == 200 + + assert isinstance(response.json(), list) + assert len(response.json()) >= 1 From a84527bd282adce72d99138afd3e0e954589978d Mon Sep 17 00:00:00 2001 From: Jay Prakash <0freerunning@gmail.com> Date: Mon, 5 May 2025 00:18:29 +0300 Subject: [PATCH 13/74] Fix npm build errors --- .../workspaces/[workspaceId]/page.tsx | 16 ++++--- .../[workspaceId]/users/invite/page.tsx | 44 +++++++++---------- frontend/src/components/app-sidebar.tsx | 1 - .../components/create-workspace-dialog.tsx | 5 ++- .../src/components/workspace-switcher.tsx | 1 - 5 files changed, 35 insertions(+), 32 deletions(-) diff --git a/frontend/src/app/(protected)/workspaces/[workspaceId]/page.tsx b/frontend/src/app/(protected)/workspaces/[workspaceId]/page.tsx index ecfdaae..b370b77 100644 --- a/frontend/src/app/(protected)/workspaces/[workspaceId]/page.tsx +++ 
b/frontend/src/app/(protected)/workspaces/[workspaceId]/page.tsx @@ -51,8 +51,10 @@ import { Workspace, WorkspaceUser, ApiKeyRotation } from "../types"; export default function WorkspaceDetailPage() { const params = useParams(); const router = useRouter(); - const { token, currentWorkspace, fetchWorkspaces, switchWorkspace } = + const { user: currentUser, token, currentWorkspace, fetchWorkspaces, switchWorkspace } = useAuth(); + + console.log("Current workspace:", currentWorkspace); const { toast } = useToast(); const [workspace, setWorkspace] = useState(null); @@ -149,10 +151,11 @@ export default function WorkspaceDetailPage() { description: "API key rotated successfully", variant: "success", }); - } catch (error: any) { + } catch (error: Error | unknown) { + const errorMessage = error instanceof Error ? error.message : "Failed to rotate API key"; toast({ title: "Error", - description: error.message || "Failed to rotate API key", + description: errorMessage, variant: "destructive", }); } finally { @@ -185,10 +188,11 @@ export default function WorkspaceDetailPage() { description: `${username} has been removed from the workspace`, variant: "success", }); - } catch (error: any) { + } catch (error: unknown) { + const errorMessage = error instanceof Error ? 
error.message : "Failed to remove user"; toast({ title: "Error", - description: error.message || "Failed to remove user", + description: errorMessage, variant: "destructive", }); } @@ -422,7 +426,7 @@ export default function WorkspaceDetailPage() { // Find the current user to determine if they have admin rights const isCurrentUserAdmin = workspaceUsers.find( - (u) => u.username === currentWorkspace?.username + (u) => u.username === currentUser )?.role === "admin"; return ( diff --git a/frontend/src/app/(protected)/workspaces/[workspaceId]/users/invite/page.tsx b/frontend/src/app/(protected)/workspaces/[workspaceId]/users/invite/page.tsx index d2d2644..67af1eb 100644 --- a/frontend/src/app/(protected)/workspaces/[workspaceId]/users/invite/page.tsx +++ b/frontend/src/app/(protected)/workspaces/[workspaceId]/users/invite/page.tsx @@ -66,9 +66,9 @@ export default function InviteUserPage() { const { toast } = useToast(); const [isInviting, setIsInviting] = useState(false); const [inviteSuccess, setInviteSuccess] = useState(false); - + const workspaceId = Number(params.workspaceId); - + // Initialize form const form = useForm({ resolver: zodResolver(InviteFormSchema), @@ -77,23 +77,23 @@ export default function InviteUserPage() { role: "read_only", // Default to read-only }, }); - + const onSubmit = async (data: InviteFormValues) => { if (!token) return; - + setIsInviting(true); try { // First need to get workspace name const workspace = await apiCalls.getWorkspaceById(token, workspaceId); - + // Send invitation - const response = await apiCalls.inviteUserToWorkspace( + await apiCalls.inviteUserToWorkspace( token, data.email, workspace.workspace_name, data.role ); - + // Show success message setInviteSuccess(true); toast({ @@ -101,20 +101,20 @@ export default function InviteUserPage() { description: `${data.email} has been invited to the workspace`, variant: "success", }); - + // Reset form form.reset(); - } catch (error: any) { + } catch (error: Error | unknown) { 
toast({ title: "Error", - description: error.message || "Failed to send invitation", + description: error instanceof Error ? error.message : "Failed to send invitation", variant: "destructive", }); } finally { setIsInviting(false); } }; - + return (
@@ -138,7 +138,7 @@ export default function InviteUserPage() {
- +
@@ -151,7 +151,7 @@ export default function InviteUserPage() { Back
- + Invite User to Workspace @@ -178,15 +178,15 @@ export default function InviteUserPage() { )} /> - + ( Role - + + + The email address of the person you want to invite + + + + )} + /> + + ( + + Role + + + Administrators can manage workspace settings and users. Regular users have read-only access. + + + + )} + /> + + + + + + {inviteSuccess && ( +
+

+ Invitation Sent Successfully +

+

+ An email has been sent to the user with instructions to join the workspace. +

+
+ )} + + + + +
+
+ ); +} diff --git a/frontend/src/app/(protected)/workspaces/page.tsx b/frontend/src/app/(protected)/workspaces/page.tsx new file mode 100644 index 0000000..872d091 --- /dev/null +++ b/frontend/src/app/(protected)/workspaces/page.tsx @@ -0,0 +1,243 @@ +"use client"; + +import React, { useState, useEffect } from "react"; +import { useAuth } from "@/utils/auth"; +import { Building, Plus, Settings } from "lucide-react"; +import { Button } from "@/components/ui/button"; +import { + Card, + CardContent, + CardDescription, + CardFooter, + CardHeader, + CardTitle, +} from "@/components/ui/card"; +import { + Breadcrumb, + BreadcrumbItem, + BreadcrumbLink, + BreadcrumbList, + BreadcrumbPage, + BreadcrumbSeparator, +} from "@/components/ui/breadcrumb"; +import Hourglass from "@/components/Hourglass"; +import { Badge } from "@/components/ui/badge"; +import { Separator } from "@/components/ui/separator"; +import Link from "next/link"; +import { CreateWorkspaceDialog } from "@/components/create-workspace-dialog"; +import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs"; +import router from "next/router"; + +export default function WorkspacesPage() { + const { workspaces, currentWorkspace, fetchWorkspaces, switchWorkspace } = useAuth(); + const [isLoading, setIsLoading] = useState(true); + const [isCreateDialogOpen, setIsCreateDialogOpen] = useState(false); + + useEffect(() => { + const loadWorkspaces = async () => { + setIsLoading(true); + try { + await fetchWorkspaces(); + } catch (error) { + console.error("Error loading workspaces:", error); + } finally { + setIsLoading(false); + } + }; + + loadWorkspaces(); + }, []); + + const handleWorkspaceSwitch = async (workspaceName: string) => { + try { + setIsLoading(true); + await switchWorkspace(workspaceName); + } catch (error) { + console.error("Error switching workspace:", error); + } finally { + setIsLoading(false); + } + }; + + if (isLoading) { + return ( +
+
+ + Loading workspaces... +
+
+ ); + } + + return ( + <> +
+
+ + + + Home + + + + Workspaces + + + +
+ +
+
+

Workspaces

+

+ Manage your workspaces and team members +

+
+ +
+ + + + All Workspaces + Current Workspace + + + +
+ {workspaces.map((workspace) => ( + + +
+ {workspace.is_default && ( + Default + )} + {currentWorkspace?.workspace_id === workspace.workspace_id && ( + Current + )} +
+
+
+ +
+ {workspace.workspace_name} +
+ + ID: {workspace.workspace_id} + +
+ +
+
+ API Quota: + {workspace.api_daily_quota} calls/day +
+
+ Content Quota: + {workspace.content_quota} experiments +
+
+ API Key: + {workspace.api_key_first_characters}••••• +
+
+ Created: + {new Date(workspace.created_datetime_utc).toLocaleDateString()} +
+
+
+ + + + + + + +
+ ))} +
+
+ + + {currentWorkspace && ( +
+ + +
+
+ +
+
+ {currentWorkspace.workspace_name} + + Workspace ID: {currentWorkspace.workspace_id} + +
+
+
+ +
+
+

Workspace Details

+
+
Created:
+
{new Date(currentWorkspace.created_datetime_utc).toLocaleDateString()}
+ +
Last Updated:
+
{new Date(currentWorkspace.updated_datetime_utc).toLocaleDateString()}
+ +
API Daily Quota:
+
{currentWorkspace.api_daily_quota} calls/day
+ +
Content Quota:
+
{currentWorkspace.content_quota} experiments
+
+
+ +
+

API Configuration

+
+
API Key Prefix:
+
{currentWorkspace.api_key_first_characters}•••••
+ +
Key Last Rotated:
+
{new Date(currentWorkspace.api_key_updated_datetime_utc).toLocaleDateString()}
+
+
+
+
+ + + +
+
+ )} +
+
+
+ + + + ); +} diff --git a/frontend/src/app/(protected)/workspaces/types.ts b/frontend/src/app/(protected)/workspaces/types.ts new file mode 100644 index 0000000..f604ff5 --- /dev/null +++ b/frontend/src/app/(protected)/workspaces/types.ts @@ -0,0 +1,32 @@ +export interface Workspace { + workspace_id: number; + workspace_name: string; + api_daily_quota: number; + content_quota: number; + api_key_first_characters: string; + api_key_updated_datetime_utc: string; + api_key_rotated_by_user_id?: number; + api_key_rotated_by_username?: string; + created_datetime_utc: string; + updated_datetime_utc: string; + is_default: boolean; +} + +export interface WorkspaceUser { + user_id: number; + username: string; + first_name: string; + last_name: string; + role: string; + is_default_workspace: boolean; + created_datetime_utc: string; +} + +export interface ApiKeyRotation { + rotation_id: number; + workspace_id: number; + rotated_by_user_id: number; + rotated_by_username: string; + key_first_characters: string; + rotation_datetime_utc: string; +} diff --git a/frontend/src/components/app-sidebar.tsx b/frontend/src/components/app-sidebar.tsx index c56c58d..cf2666a 100644 --- a/frontend/src/components/app-sidebar.tsx +++ b/frontend/src/components/app-sidebar.tsx @@ -1,16 +1,9 @@ "use client"; import * as React from "react"; import { - AudioWaveform, - ArrowLeftRightIcon, LayoutDashboardIcon, - Command, - Frame, - GalleryVerticalEnd, - Map, - PieChart, - Settings2, FlaskConicalIcon, + Settings2, } from "lucide-react"; import { NavMain } from "@/components/nav-main"; import { NavRecentExperiments } from "@/components/nav-recent-experiments"; @@ -23,79 +16,19 @@ import { SidebarHeader, SidebarRail, } from "@/components/ui/sidebar"; -import api from "@/utils/api"; -import { apiCalls } from "@/utils/api"; import { useAuth } from "@/utils/auth"; -type UserDetails = { - username: string; - firstName: string; - lastName: string; - isActive: boolean; - isVerified: boolean; -}; - -const 
getUserDetails = async (token: string | null) => { - try { - if (token) { - const response = await apiCalls.getUser(token); - if (!response) { - throw new Error("No response from server"); - } - return { - username: response.username, - firstName: response.first_name, - lastName: response.last_name, - isActive: response.is_active, - isVerified: response.is_verified, - } as UserDetails; - } else { - throw new Error("No token provided"); - } -} catch (error: unknown) { - if (error instanceof Error) { - throw new Error(`Error fetching user details: ${error.message}`); - } else { - throw new Error("Error fetching user details"); - } - } -}; +const AppSidebar = React.memo(function AppSidebar({ + ...props +}: React.ComponentProps) { + const { user, firstName, lastName } = useAuth(); -// This is sample data. -const data = { - user: { - name: "shadcn", - email: "m@example.com", - avatar: "/avatars/shadcn.jpg", - }, - workspaces: [ - { - name: "Acme Inc", - logo: GalleryVerticalEnd, - plan: "Enterprise", - }, - { - name: "Acme Corp.", - logo: AudioWaveform, - plan: "Startup", - }, - { - name: "Evil Corp.", - logo: Command, - plan: "Free", - }, - ], - navMain: [ + const navMain = [ { title: "Experiments", url: "/experiments", icon: FlaskConicalIcon, }, - { - title: "Integration", - url: "/integration", - icon: ArrowLeftRightIcon, - }, { title: "Dashboard", url: "#", @@ -106,48 +39,40 @@ const data = { url: "#", icon: Settings2, }, - ], - recentExperiments: [ + ]; + + const recentExperiments = [ { name: "New onboarding flows", url: "#", - icon: Frame, + icon: FlaskConicalIcon, }, { name: "3 different voices", url: "#", - icon: PieChart, + icon: FlaskConicalIcon, }, { name: "AI responses", url: "#", - icon: Map, + icon: FlaskConicalIcon, }, - ], -}; -const AppSidebar = React.memo(function AppSidebar({ - ...props -}: React.ComponentProps) { - const { token } = useAuth(); - const [userDetails, setUserDetails] = React.useState( - null - ); + ]; + + const userDetails = { + 
firstName: firstName || "?", + lastName: lastName || "?", + username: user || "loading", + }; - React.useEffect(() => { - if (token) { - getUserDetails(token) - .then((data) => setUserDetails(data)) - .catch((error) => console.error(error)); - } - }, [token]); return ( - + - - + + diff --git a/frontend/src/components/create-workspace-dialog.tsx b/frontend/src/components/create-workspace-dialog.tsx new file mode 100644 index 0000000..f349fba --- /dev/null +++ b/frontend/src/components/create-workspace-dialog.tsx @@ -0,0 +1,153 @@ +"use client"; + +import { useState } from "react"; +import { apiCalls } from "@/utils/api"; +import { useAuth } from "@/utils/auth"; +import { useToast } from "@/hooks/use-toast"; +import { useRouter } from "next/navigation"; +import { z } from "zod"; +import { useForm } from "react-hook-form"; +import { zodResolver } from "@hookform/resolvers/zod"; + +import { + Dialog, + DialogContent, + DialogDescription, + DialogFooter, + DialogHeader, + DialogTitle, +} from "@/components/ui/dialog"; +import { Button } from "@/components/ui/button"; +import { + Form, + FormControl, + FormDescription, + FormField, + FormItem, + FormLabel, + FormMessage, +} from "@/components/ui/form"; +import { Input } from "@/components/ui/input"; + +// Create schema for workspace form +const WorkspaceFormSchema = z.object({ + workspaceName: z + .string() + .min(3, "Workspace name must be at least 3 characters") + .max(50, "Workspace name must be less than 50 characters"), + apiDailyQuota: z + .number() + .int("API quota must be an integer") + .min(1, "API quota must be at least 1") + .optional(), + contentQuota: z + .number() + .int("Content quota must be an integer") + .min(1, "Content quota must be at least 1") + .optional(), +}); + +type WorkspaceFormValues = z.infer; + +interface CreateWorkspaceDialogProps { + open: boolean; + onOpenChange: (open: boolean) => void; +} + +export function CreateWorkspaceDialog({ + open, + onOpenChange, +}: 
CreateWorkspaceDialogProps) { + const { token, fetchWorkspaces } = useAuth(); + const router = useRouter(); + const { toast } = useToast(); + const [isCreating, setIsCreating] = useState(false); + + // Initialize form + const form = useForm({ + resolver: zodResolver(WorkspaceFormSchema), + defaultValues: { + workspaceName: "", + apiDailyQuota: undefined, + contentQuota: undefined, + }, + }); + + async function onSubmit(data: WorkspaceFormValues) { + if (!token) return; + + setIsCreating(true); + try { + await apiCalls.createWorkspace( + token, + data.workspaceName, + data.apiDailyQuota, + data.contentQuota + ); + + toast({ + title: "Workspace created", + description: `${data.workspaceName} workspace was created successfully.`, + variant: "success", + }); + + // Refresh the workspaces list + await fetchWorkspaces(); + + // Close the dialog and reset form + onOpenChange(false); + form.reset(); + + // Navigate to workspaces page + router.push("/workspaces"); + } catch (error: Error | unknown) { + const errorMessage = error instanceof Error ? error.message : "Failed to create workspace. Please try again."; + toast({ + title: "Error creating workspace", + description: errorMessage, + variant: "destructive", + }); + } finally { + setIsCreating(false); + } + } + + return ( + + + + Create New Workspace + + Create a new workspace for your team or project + + + +
+ + ( + + Workspace Name + + + + + This is the name that will be displayed for your workspace + + + + )} + /> + + + + + +
+
+ ); +} diff --git a/frontend/src/components/nav-user.tsx b/frontend/src/components/nav-user.tsx index a7b1994..9abe095 100644 --- a/frontend/src/components/nav-user.tsx +++ b/frontend/src/components/nav-user.tsx @@ -18,6 +18,7 @@ import { SidebarMenuItem, useSidebar, } from "@/components/ui/sidebar"; +import Link from "next/link"; export function NavUser({ user, @@ -97,10 +98,12 @@ export function NavUser({ Account - - - Manage Workspace - + + + + Manage Workspace + + diff --git a/frontend/src/components/ui/alert-dialog.tsx b/frontend/src/components/ui/alert-dialog.tsx new file mode 100644 index 0000000..9f16798 --- /dev/null +++ b/frontend/src/components/ui/alert-dialog.tsx @@ -0,0 +1,143 @@ +"use client" + +import * as React from "react" +import * as AlertDialogPrimitive from "@radix-ui/react-alert-dialog" + +import { cn } from "@/lib/utils" +import { buttonVariants } from "@/components/ui/button" + +const AlertDialog = AlertDialogPrimitive.Root +const AlertDialogTrigger = AlertDialogPrimitive.Trigger +const AlertDialogPortal = AlertDialogPrimitive.Portal + +const AlertDialogOverlay = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +AlertDialogOverlay.displayName = AlertDialogPrimitive.Overlay.displayName + +const AlertDialogContent = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + + + + +)) +AlertDialogContent.displayName = AlertDialogPrimitive.Content.displayName + +const AlertDialogHeader = ({ + className, + ...props +}: React.HTMLAttributes) => ( +
+) +AlertDialogHeader.displayName = "AlertDialogHeader" + +const AlertDialogFooter = ({ + className, + ...props +}: React.HTMLAttributes) => ( +
+) +AlertDialogFooter.displayName = "AlertDialogFooter" + +const AlertDialogTitle = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +AlertDialogTitle.displayName = AlertDialogPrimitive.Title.displayName + +const AlertDialogDescription = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +AlertDialogDescription.displayName = + AlertDialogPrimitive.Description.displayName + +const AlertDialogAction = React.forwardRef< + HTMLButtonElement, + React.ButtonHTMLAttributes +>(({ className, ...props }, ref) => ( +
) : ( @@ -421,87 +428,74 @@ export default function WorkspaceDetailPage() {
Joined
Actions
- {workspaceUsers.map((user) => { - // Find the current user to determine if they have admin rights - const isCurrentUserAdmin = - workspaceUsers.find( - (u) => u.username === currentUser - )?.role === "admin"; - - return ( -
-
-
- {user.first_name} {user.last_name} -
-
- {user.username} -
-
-
- {user.role.toLowerCase()} - {user.is_default_workspace && ( - - Default - - )} + {workspaceUsers.map((user) => ( +
+
+
+ {user.first_name} {user.last_name}
-
- {new Date( - user.created_datetime_utc - ).toLocaleDateString()} +
+ {user.username}
-
- {!workspace.is_default && isCurrentUserAdmin && ( - - - + + + + + Remove user? + + + Are you sure you want to remove{" "} + {user.first_name} {user.last_name} from + this workspace? + + + + + Cancel + + + handleRemoveUser(user.username) + } > Remove - - - - - - Remove user? - - - Are you sure you want to remove{" "} - {user.first_name} {user.last_name} from - this workspace? - - - - { - e.preventDefault(); - e.stopPropagation(); - }} - > - Cancel - - - handleRemoveUser(user.username) - } - > - Remove - - - - - )} -
+ + + + + )}
- ); - })} +
+ ))}
)} @@ -513,14 +507,16 @@ export default function WorkspaceDetailPage() { API Configuration - + {isCurrentUserAdmin && ( + + )} Manage API settings for this workspace @@ -617,18 +613,20 @@ export default function WorkspaceDetailPage() { -
-

- About API Key Rotation -

-

- When you rotate your API key, the old key will be - immediately invalidated. Any services or applications using - the old key will need to be updated with the new key. Make - sure to copy and save your new key as it will only be shown - once. -

-
+ {isCurrentUserAdmin && ( +
+

+ About API Key Rotation +

+

+ When you rotate your API key, the old key will be + immediately invalidated. Any services or applications using + the old key will need to be updated with the new key. Make + sure to copy and save your new key as it will only be shown + once. +

+
+ )} From 2cbf107346e76c5aa1a4e88a346131edf71e8e86 Mon Sep 17 00:00:00 2001 From: Jay Prakash <0freerunning@gmail.com> Date: Wed, 21 May 2025 21:18:40 +0530 Subject: [PATCH 26/74] Fix the removal of user from workspace --- backend/app/workspaces/models.py | 23 +++++++++++ frontend/src/components/ui/alert-dialog.tsx | 24 ++++++----- .../src/components/workspace-switcher.tsx | 40 ++++++++++++++++--- frontend/src/utils/api.ts | 12 ++++++ 4 files changed, 83 insertions(+), 16 deletions(-) diff --git a/backend/app/workspaces/models.py b/backend/app/workspaces/models.py index ab907fb..b79aeab 100644 --- a/backend/app/workspaces/models.py +++ b/backend/app/workspaces/models.py @@ -254,8 +254,31 @@ async def remove_user_from_workspace( f"'{workspace_db.workspace_name}'." ) + was_default = user_workspace.default_workspace + # Delete the relationship await asession.delete(user_workspace) + + # If this was their default workspace, reset to their personal default workspace + if was_default: + stmt = ( + select(UserWorkspaceDB) + .join(WorkspaceDB, UserWorkspaceDB.workspace_id == WorkspaceDB.workspace_id) + .where( + and_( + UserWorkspaceDB.user_id == user_db.user_id, + WorkspaceDB.is_default.is_(True), + ) + ) + ) + result = await asession.execute(stmt) + personal_workspace = result.scalar_one_or_none() + + if personal_workspace: + personal_workspace.default_workspace = True + personal_workspace.updated_datetime_utc = datetime.now(timezone.utc) + asession.add(personal_workspace) + await asession.commit() diff --git a/frontend/src/components/ui/alert-dialog.tsx b/frontend/src/components/ui/alert-dialog.tsx index 1e2629a..fdd68b2 100644 --- a/frontend/src/components/ui/alert-dialog.tsx +++ b/frontend/src/components/ui/alert-dialog.tsx @@ -114,17 +114,19 @@ AlertDialogAction.displayName = "AlertDialogAction" const AlertDialogCancel = React.forwardRef< HTMLButtonElement, - React.ButtonHTMLAttributes + React.ButtonHTMLAttributes & { asChild?: boolean } >(({ className, ...props }, 
ref) => ( -
- {currentWorkspace.workspace_name} + {currentWorkspace?.workspace_name || "No Workspace"} {isLoading ? "Switching..." : "Workspace"} @@ -109,7 +109,7 @@ export function WorkspaceSwitcher() { handleWorkspaceSwitch(workspace.workspace_name)} - className={`gap-2 p-2 ${workspace.workspace_id === currentWorkspace.workspace_id ? "bg-accent" : ""}`} + className={`gap-2 p-2 ${currentWorkspace && workspace.workspace_id === currentWorkspace.workspace_id ? "bg-accent" : ""}`} >
"""SQLAlchemy ORM models for experiments, arms, draws, contexts, clients and
notifications.

All experiment types (MAB, CMAB, Bayes A/B) share the single ``experiments``
table and are discriminated by the ``exp_type`` column (single-table
polymorphism via ``__mapper_args__``).
"""

import uuid
from datetime import datetime
from typing import TYPE_CHECKING, Optional, Sequence

from sqlalchemy import (
    Boolean,
    DateTime,
    Enum,
    Float,
    ForeignKey,
    Integer,
    String,
    select,
)
from sqlalchemy.dialects.postgresql import ARRAY
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship

from .schemas import (
    AutoFailUnitType,
    EventType,
    Notifications,
    ObservationType,
)

if TYPE_CHECKING:
    # NOTE(review): this module lives in app/experiments/, so the workspaces
    # package is one level up. The original single-dot import
    # (`from .workspaces.models import ...`) would resolve to
    # app.experiments.workspaces, which does not exist — confirm against the
    # package layout.
    from ..workspaces.models import WorkspaceDB


# Base class for SQLAlchemy models
class Base(DeclarativeBase):
    """Declarative base class for all experiment-related ORM models."""

    pass


# --- Base model for experiments ---
class ExperimentDB(Base):
    """Polymorphic base model for experiments.

    Concrete experiment types set ``exp_type`` (the polymorphic discriminator)
    and inherit all columns and relationships declared here.
    """

    __tablename__ = "experiments"

    # IDs
    experiment_id: Mapped[int] = mapped_column(
        Integer, primary_key=True, nullable=False
    )
    user_id: Mapped[int] = mapped_column(
        Integer, ForeignKey("users.user_id"), nullable=False
    )
    workspace_id: Mapped[int] = mapped_column(
        Integer, ForeignKey("workspace.workspace_id"), nullable=False
    )

    # Description
    name: Mapped[str] = mapped_column(String(length=150), nullable=False)
    description: Mapped[str] = mapped_column(String(length=500), nullable=False)
    is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)

    # Assignments config
    sticky_assignment: Mapped[bool] = mapped_column(
        Boolean, nullable=False, default=False
    )
    auto_fail: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
    # auto_fail_value/auto_fail_unit are only meaningful when auto_fail is True
    # (enforced by the Experiment schema validator, not at the DB level).
    auto_fail_value: Mapped[int] = mapped_column(Integer, nullable=True)
    auto_fail_unit: Mapped[AutoFailUnitType] = mapped_column(
        Enum(AutoFailUnitType), nullable=True
    )

    # Experiment config
    exp_type: Mapped[str] = mapped_column(String(length=50), nullable=False)
    prior_type: Mapped[str] = mapped_column(String(length=50), nullable=False)
    reward_type: Mapped[str] = mapped_column(String(length=50), nullable=False)

    # State variables
    created_datetime_utc: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False
    )
    n_trials: Mapped[int] = mapped_column(Integer, nullable=False)
    last_trial_datetime_utc: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=True
    )

    # Relationships
    workspace: Mapped["WorkspaceDB"] = relationship(
        "WorkspaceDB", back_populates="experiments"
    )
    arms: Mapped[list["ArmDB"]] = relationship(
        "ArmDB", back_populates="experiment", lazy="joined"
    )
    draws: Mapped[list["DrawDB"]] = relationship(
        "DrawDB",
        back_populates="experiment",
        primaryjoin="ExperimentDB.experiment_id==DrawDB.experiment_id",
        lazy="joined",
    )
    clients: Mapped[list["ClientDB"]] = relationship(
        "ClientDB",
        back_populates="experiment",
        lazy="joined",
    )
    # Contexts only apply to contextual experiments; the extra join condition
    # keeps the collection empty for non-cmab experiment types.
    contexts: Mapped[Optional[list["ContextDB"]]] = relationship(
        "ContextDB",
        back_populates="experiment",
        lazy="joined",
        primaryjoin="and_(ExperimentDB.experiment_id==ContextDB.experiment_id, "
        "ExperimentDB.exp_type=='cmab')",
    )

    __mapper_args__ = {
        "polymorphic_identity": "experiment",
        "polymorphic_on": "exp_type",
    }

    def __repr__(self) -> str:
        """Return a debug representation of the experiment."""
        # NOTE(review): the original f-string content was lost in transit
        # (angle-bracketed text stripped); reconstructed — confirm format.
        return f"<ExperimentDB(experiment_id={self.experiment_id}, name={self.name!r})>"

    @property
    def has_contexts(self) -> bool:
        """Check if this experiment type supports contexts."""
        return self.exp_type == "cmab"

    @property
    def context_list(self) -> list["ContextDB"]:
        """Get contexts, returning an empty list if not applicable."""
        return self.contexts if self.has_contexts else []

    def to_dict(self) -> dict:
        """Convert the ORM object to a dictionary (arms/draws/contexts included)."""
        return {
            "experiment_id": self.experiment_id,
            "user_id": self.user_id,
            "workspace_id": self.workspace_id,
            "name": self.name,
            "description": self.description,
            "sticky_assignment": self.sticky_assignment,
            "auto_fail": self.auto_fail,
            "auto_fail_value": self.auto_fail_value,
            "auto_fail_unit": self.auto_fail_unit,
            "exp_type": self.exp_type,
            "prior_type": self.prior_type,
            "reward_type": self.reward_type,
            "created_datetime_utc": self.created_datetime_utc,
            "is_active": self.is_active,
            "n_trials": self.n_trials,
            "last_trial_datetime_utc": self.last_trial_datetime_utc,
            "arms": [arm.to_dict() for arm in self.arms],
            "draws": [draw.to_dict() for draw in self.draws],
            "contexts": (
                [context.to_dict() for context in self.context_list]
                if self.has_contexts
                else None
            ),
        }


# --- Arm model ---
class ArmDB(Base):
    """Model for an arm of an experiment.

    Holds both the prior hyper-parameters (``*_init``) and the current
    posterior parameters (beta: alpha/beta; normal: mu/covariance).
    """

    __tablename__ = "arms"

    # IDs
    arm_id: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False)
    experiment_id: Mapped[int] = mapped_column(
        Integer, ForeignKey("experiments.experiment_id"), nullable=False
    )
    user_id: Mapped[int] = mapped_column(
        Integer, ForeignKey("users.user_id"), nullable=False
    )

    # Description
    name: Mapped[str] = mapped_column(String(length=150), nullable=False)
    description: Mapped[str] = mapped_column(String(length=500), nullable=False)
    n_outcomes: Mapped[int] = mapped_column(Integer, nullable=False, default=0)

    # Prior variables (normal prior)
    mu_init: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    sigma_init: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    mu: Mapped[Optional[list[float]]] = mapped_column(ARRAY(Float), nullable=True)
    covariance: Mapped[Optional[list[float]]] = mapped_column(
        ARRAY(Float), nullable=True
    )

    # Prior variables (beta prior)
    alpha_init: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    beta_init: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    alpha: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    beta: Mapped[Optional[float]] = mapped_column(Float, nullable=True)

    # Relationships
    experiment: Mapped[ExperimentDB] = relationship(
        "ExperimentDB", back_populates="arms", lazy="joined"
    )
    draws: Mapped[list["DrawDB"]] = relationship(
        "DrawDB",
        back_populates="arm",
        lazy="joined",
    )

    def to_dict(self) -> dict:
        """Convert the ORM object to a dictionary."""
        return {
            "arm_id": self.arm_id,
            "experiment_id": self.experiment_id,
            "name": self.name,
            "description": self.description,
            "alpha": self.alpha,
            "beta": self.beta,
            "mu": self.mu,
            "covariance": self.covariance,
            "alpha_init": self.alpha_init,
            "beta_init": self.beta_init,
            "mu_init": self.mu_init,
            "sigma_init": self.sigma_init,
            "draws": [draw.to_dict() for draw in self.draws],
        }


# --- Draw model ---
class DrawDB(Base):
    """Model for a single draw (arm assignment) and its eventual outcome."""

    __tablename__ = "draws"

    # IDs
    draw_id: Mapped[str] = mapped_column(
        String, primary_key=True, default=lambda: str(uuid.uuid4())
    )
    arm_id: Mapped[int] = mapped_column(
        Integer, ForeignKey("arms.arm_id"), nullable=False
    )
    experiment_id: Mapped[int] = mapped_column(
        Integer, ForeignKey("experiments.experiment_id"), nullable=False
    )
    user_id: Mapped[int] = mapped_column(
        Integer, ForeignKey("users.user_id"), nullable=False
    )
    # FIX: the original used `client_id = Mapped[str] = mapped_column(...)`
    # (chained assignment) — the column annotation must use `:`.
    client_id: Mapped[str] = mapped_column(
        String, ForeignKey("clients.client_id"), nullable=True
    )

    # Logging
    draw_datetime_utc: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
    )
    observed_datetime_utc: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=True
    )
    observation_type: Mapped[ObservationType] = mapped_column(
        Enum(ObservationType), nullable=True
    )
    reward: Mapped[float] = mapped_column(Float, nullable=True)
    # FIX: same `=` vs `:` typo as client_id above.
    context_val: Mapped[Optional[list[float]]] = mapped_column(
        ARRAY(Float), nullable=True
    )

    # Relationships
    arm: Mapped[ArmDB] = relationship("ArmDB", back_populates="draws", lazy="joined")
    experiment: Mapped[ExperimentDB] = relationship(
        "ExperimentDB", back_populates="draws", lazy="joined"
    )
    client: Mapped[Optional["ClientDB"]] = relationship(
        "ClientDB",
        back_populates="draws",
        lazy="joined",
        primaryjoin="and_(DrawDB.client_id==ClientDB.client_id, ExperimentDB.sticky_assignment == True)",  # noqa: E501
    )

    def to_dict(self) -> dict:
        """Convert the ORM object to a dictionary."""
        return {
            "draw_id": self.draw_id,
            "arm_id": self.arm_id,
            "experiment_id": self.experiment_id,
            "user_id": self.user_id,
            "client_id": self.client_id,
            "draw_datetime_utc": self.draw_datetime_utc,
            "observed_datetime_utc": self.observed_datetime_utc,
            "observation_type": self.observation_type,
            "reward": self.reward,
            "context_val": self.context_val,
        }


# --- Context model ---
class ContextDB(Base):
    """ORM for managing context variables of a contextual (cmab) experiment."""

    __tablename__ = "contexts"

    # IDs
    context_id: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False)
    # FIX: the original FK pointed at the legacy `contextual_mabs` table; in
    # this unified schema experiments live in the `experiments` table (and the
    # ExperimentDB.contexts primaryjoin joins on experiments.experiment_id).
    experiment_id: Mapped[int] = mapped_column(
        Integer, ForeignKey("experiments.experiment_id"), nullable=False
    )
    user_id: Mapped[int] = mapped_column(
        Integer, ForeignKey("users.user_id"), nullable=False
    )

    # Description
    name: Mapped[str] = mapped_column(String(length=150), nullable=False)
    description: Mapped[str] = mapped_column(String(length=500), nullable=True)
    value_type: Mapped[str] = mapped_column(String(length=50), nullable=False)

    # Relationships
    experiment: Mapped[ExperimentDB] = relationship(
        "ExperimentDB", back_populates="contexts", lazy="joined"
    )

    def to_dict(self) -> dict:
        """Convert the ORM object to a dictionary."""
        return {
            "context_id": self.context_id,
            "name": self.name,
            "description": self.description,
            "value_type": self.value_type,
        }


# --- Client model ---
class ClientDB(Base):
    """ORM for managing clients (sticky-assignment identities) of an experiment."""

    __tablename__ = "clients"

    # IDs
    client_id: Mapped[str] = mapped_column(
        String, primary_key=True, default=lambda: str(uuid.uuid4())
    )
    user_id: Mapped[int] = mapped_column(
        Integer, ForeignKey("users.user_id"), nullable=False
    )
    experiment_id: Mapped[int] = mapped_column(
        Integer, ForeignKey("experiments.experiment_id"), nullable=False
    )
    workspace_id: Mapped[int] = mapped_column(
        Integer, ForeignKey("workspace.workspace_id"), nullable=False
    )

    # Relationships
    draws: Mapped[list[DrawDB]] = relationship(
        "DrawDB",
        back_populates="client",
        lazy="joined",
    )


# --- Notifications model ---
class NotificationsDB(Base):
    """Model for notifications.

    Note: if you are updating this, you should also update the models in
    the background celery job.
    """

    __tablename__ = "notifications"

    notification_id: Mapped[int] = mapped_column(
        Integer, primary_key=True, nullable=False
    )
    experiment_id: Mapped[int] = mapped_column(
        Integer, ForeignKey("experiments.experiment_id"), nullable=False
    )
    user_id: Mapped[int] = mapped_column(
        Integer, ForeignKey("users.user_id"), nullable=False
    )
    notification_type: Mapped[EventType] = mapped_column(
        Enum(EventType), nullable=False
    )
    notification_value: Mapped[int] = mapped_column(Integer, nullable=False)
    is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)

    def to_dict(self) -> dict:
        """Convert the model to a dictionary."""
        return {
            "notification_id": self.notification_id,
            "experiment_id": self.experiment_id,
            "user_id": self.user_id,
            "notification_type": self.notification_type,
            "notification_value": self.notification_value,
            "is_active": self.is_active,
        }
# --- Experiments functions ---


# ---- Notifications functions ----
async def save_notifications_to_db(
    experiment_id: int,
    user_id: int,
    notifications: Notifications,
    asession: AsyncSession,
) -> list[NotificationsDB]:
    """Persist the enabled notification triggers for an experiment.

    One NotificationsDB row is created per enabled trigger; the list of
    created rows is returned after a single commit.
    """
    # (enabled-flag, event type, associated threshold value) per trigger.
    trigger_specs = [
        (
            notifications.onTrialCompletion,
            EventType.TRIALS_COMPLETED,
            notifications.numberOfTrials,
        ),
        (
            notifications.onDaysElapsed,
            EventType.DAYS_ELAPSED,
            notifications.daysElapsed,
        ),
        (
            notifications.onPercentBetter,
            EventType.PERCENTAGE_BETTER,
            notifications.percentBetterThreshold,
        ),
    ]

    notification_records = [
        NotificationsDB(
            experiment_id=experiment_id,
            user_id=user_id,
            notification_type=event_type,
            notification_value=threshold,
            is_active=True,
        )
        for enabled, event_type, threshold in trigger_specs
        if enabled
    ]

    asession.add_all(notification_records)
    await asession.commit()

    return notification_records


async def get_notifications_from_db(
    experiment_id: int, user_id: int, asession: AsyncSession
) -> Sequence[NotificationsDB]:
    """Fetch all notification rows for the given experiment and user."""
    statement = (
        select(NotificationsDB)
        .where(NotificationsDB.experiment_id == experiment_id)
        .where(NotificationsDB.user_id == user_id)
    )

    return (await asession.execute(statement)).scalars().all()


# --- new file: backend/app/experiments/schemas.py ---
from enum import Enum, StrEnum
from typing import Any, Optional, Self

import numpy as np
from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic.types import NonNegativeInt
from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic.types import NonNegativeInt


# --- Enums ---
class ExperimentsEnum(StrEnum):
    """Enum for the experiment types."""

    MAB = "mab"
    CMAB = "cmab"
    BAYESAB = "bayes_ab"


class EventType(StrEnum):
    """Types of events that can trigger a notification."""

    DAYS_ELAPSED = "days_elapsed"
    TRIALS_COMPLETED = "trials_completed"
    PERCENTAGE_BETTER = "percentage_better"


class ObservationType(StrEnum):
    """Types of observations that can be made."""

    USER = "user"  # Generated by the user
    AUTO = "auto"  # Generated by the system


class AutoFailUnitType(StrEnum):
    """Types of units for auto fail."""

    DAYS = "days"
    HOURS = "hours"


class Outcome(float, Enum):
    """Enum for the binary outcome of a trial."""

    SUCCESS = 1
    FAILURE = 0


class ArmPriors(StrEnum):
    """Enum for the prior distribution of the arm."""

    BETA = "beta"
    NORMAL = "normal"

    def __call__(self, theta: np.ndarray, **kwargs: Any) -> np.ndarray:
        """Return the log pdf of ``theta`` under this prior.

        Beta priors read ``alpha``/``beta`` from kwargs (default: ones);
        normal priors read ``mu``/``covariance`` (default: standard normal).
        """
        if self == ArmPriors.BETA:
            alpha = kwargs.get("alpha", np.ones_like(theta))
            beta = kwargs.get("beta", np.ones_like(theta))
            return (alpha - 1) * np.log(theta) + (beta - 1) * np.log(1 - theta)

        elif self == ArmPriors.NORMAL:
            mu = kwargs.get("mu", np.zeros_like(theta))
            covariance = kwargs.get("covariance", np.diag(np.ones_like(theta)))
            inv_cov = np.linalg.inv(covariance)
            x = theta - mu
            return -0.5 * x @ inv_cov @ x


class RewardLikelihood(StrEnum):
    """Enum for the likelihood distribution of the reward."""

    BERNOULLI = "binary"
    NORMAL = "real-valued"

    def __call__(self, reward: np.ndarray, probs: np.ndarray) -> np.ndarray:
        """Calculate the log likelihood of the reward.

        Parameters
        ----------
        reward : The reward.
        probs : The probability (normal: mean) of the reward.
        """
        if self == RewardLikelihood.NORMAL:
            return -0.5 * np.sum((reward - probs) ** 2)
        elif self == RewardLikelihood.BERNOULLI:
            # NOTE(review): probs of exactly 0 or 1 produce -inf/nan here —
            # callers must clip probabilities away from the boundary.
            return np.sum(reward * np.log(probs) + (1 - reward) * np.log(1 - probs))


class ContextType(StrEnum):
    """Enum for the type of context."""

    BINARY = "binary"
    REAL_VALUED = "real-valued"


class ContextLinkFunctions(StrEnum):
    """Enum for the link function between arm params and context."""

    NONE = "none"
    LOGISTIC = "logistic"

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Apply the link function to the input param.

        Parameters
        ----------
        x : The input param.
        """
        if self == ContextLinkFunctions.NONE:
            return x
        elif self == ContextLinkFunctions.LOGISTIC:
            return 1.0 / (1.0 + np.exp(-x))


# --- Schemas ---
# Notifications schema
class Notifications(BaseModel):
    """Pydantic model for an experiment's notification settings."""

    onTrialCompletion: bool = False
    # FIX: the value fields were annotated `... | None` but had no default,
    # which made them required even when their toggle is off; default to None
    # (the validator below enforces presence when the toggle is on).
    numberOfTrials: NonNegativeInt | None = None
    onDaysElapsed: bool = False
    daysElapsed: NonNegativeInt | None = None
    onPercentBetter: bool = False
    percentBetterThreshold: NonNegativeInt | None = None

    @model_validator(mode="after")
    def validate_has_associated_value(self) -> Self:
        """Validate that each enabled trigger has a positive associated value."""
        if self.onTrialCompletion and (
            not self.numberOfTrials or self.numberOfTrials == 0
        ):
            raise ValueError(
                "numberOfTrials is required when onTrialCompletion is True"
            )
        if self.onDaysElapsed and (not self.daysElapsed or self.daysElapsed == 0):
            raise ValueError("daysElapsed is required when onDaysElapsed is True")
        if self.onPercentBetter and (
            not self.percentBetterThreshold or self.percentBetterThreshold == 0
        ):
            raise ValueError(
                "percentBetterThreshold is required when onPercentBetter is True"
            )

        return self


class NotificationsResponse(BaseModel):
    """Pydantic model for a response for notifications."""

    model_config = ConfigDict(from_attributes=True)

    notification_id: int
    notification_type: EventType
    notification_value: NonNegativeInt
    is_active: bool


# Arms
class Arm(BaseModel):
    """Pydantic model for an arm."""

    model_config = ConfigDict(from_attributes=True)

    # Description
    name: str = Field(
        max_length=150,
        examples=["Arm 1"],
    )
    description: str = Field(
        max_length=500,
        examples=["This is a description of the arm."],
    )

    # Prior variables
    alpha_init: Optional[float] = Field(
        default=None, examples=[None, 1.0], description="Alpha parameter for Beta prior"
    )
    beta_init: Optional[float] = Field(
        default=None, examples=[None, 1.0], description="Beta parameter for Beta prior"
    )
    mu_init: Optional[float] = Field(
        default=None,
        examples=[None, 0.0],
        description="Mean parameter for Normal prior",
    )
    sigma_init: Optional[float] = Field(
        default=None,
        examples=[None, 1.0],
        description="Standard deviation parameter for Normal prior",
    )

    @model_validator(mode="after")
    def check_values(self) -> Self:
        """Validate that the supplied prior hyper-parameters are positive."""
        alpha = self.alpha_init
        beta = self.beta_init
        sigma = self.sigma_init
        if alpha is not None and alpha <= 0:
            raise ValueError("Alpha must be greater than 0.")
        if beta is not None and beta <= 0:
            raise ValueError("Beta must be greater than 0.")
        if sigma is not None and sigma <= 0:
            raise ValueError("Sigma must be greater than 0.")
        return self


class ArmResponse(Arm):
    """Pydantic model for a response for arm creation."""

    arm_id: int
    experiment_id: int
    n_outcomes: int
    alpha: Optional[float]
    beta: Optional[float]
    mu: Optional[list[float]]
    covariance: Optional[list[float]]
    model_config = ConfigDict(
        from_attributes=True,
    )


# Contexts
class Context(BaseModel):
    """Pydantic model for a context variable of a (cmab) experiment."""

    name: str = Field(
        description="Name of the context",
        examples=["Context 1"],
    )
    description: str = Field(
        description="Description of the context",
        examples=["This is a description of the context."],
    )
    value_type: ContextType = Field(
        description="Type of value the context can take", default=ContextType.BINARY
    )
    model_config = ConfigDict(from_attributes=True)


class ContextResponse(Context):
    """Pydantic model for a response for context creation."""

    context_id: int
    model_config = ConfigDict(from_attributes=True)


class ContextInput(BaseModel):
    """Pydantic model for a context value supplied at draw time."""

    context_id: int
    context_value: float
    model_config = ConfigDict(from_attributes=True)


# Client
class Client(BaseModel):
    """Pydantic model for a client."""

    model_config = ConfigDict(from_attributes=True)

    client_id: str = Field(
        description="Unique identifier for the client",
        examples=["client_123"],
    )


# Draws
class Draw(BaseModel):
    """Pydantic model for a draw."""

    model_config = ConfigDict(from_attributes=True)

    # Draw info
    reward: Optional[float] = Field(
        description="Reward observed from the draw",
        default=None,
    )
    context_val: Optional[list[float]] = Field(
        description="Context values associated with the draw",
        default=None,
    )


class DrawResponse(Draw):
    """Pydantic model for a response for draw creation."""

    draw_id: str = Field(
        description="Unique identifier for the draw",
        examples=["draw_123"],
    )
    draw_datetime_utc: str = Field(
        description="Timestamp of when the draw was made",
        examples=["2023-10-01T12:00:00Z"],
    )
    observed_datetime_utc: Optional[str] = Field(
        description="Timestamp of when the reward was observed",
        default=None,
    )
    arm: ArmResponse
    # FIX: DrawDB.client is Optional (only populated for sticky assignment),
    # so the response field must be optional too.
    client: Optional[Client] = None


# Experiments
class ExperimentBase(BaseModel):
    """Pydantic base model for an experiment.

    Note: This is a base model and should not be used directly.
    Use the `Experiment` model instead.
    """

    model_config = ConfigDict(from_attributes=True)

    # Description
    name: str = Field(
        max_length=150,
        examples=["Experiment 1"],
    )
    description: str = Field(
        max_length=500,
        examples=["This is a description of the experiment."],
    )

    is_active: bool = True

    # Assignments config
    sticky_assignment: bool = Field(
        description="Whether the arm assignment is sticky or not.",
        default=False,
    )

    auto_fail: bool = Field(
        description=(
            "Whether the experiment should fail automatically after "
            "a certain period if no outcome is registered."
        ),
        default=False,
    )

    auto_fail_value: Optional[int] = Field(
        description="The time period after which the experiment should fail.",
        default=None,
    )

    auto_fail_unit: Optional[AutoFailUnitType] = Field(
        description="The time unit for the auto fail period.",
        default=None,
    )

    # Experiment config
    exp_type: ExperimentsEnum = Field(
        description="The type of experiment.",
        default=ExperimentsEnum.MAB,
    )
    prior_type: ArmPriors = Field(
        description="The type of prior distribution for the arms.",
        default=ArmPriors.BETA,
    )
    reward_type: RewardLikelihood = Field(
        description="The type of reward we observe from the experiment.",
        default=RewardLikelihood.BERNOULLI,
    )


class Experiment(ExperimentBase):
    """Pydantic model for an experiment."""

    # Relationships
    arms: list[Arm]
    contexts: Optional[list[Context]] = None
    clients: Optional[list[Client]] = None
    # FIX: the original `notifications = Notifications` assigned the class
    # object as a plain attribute instead of declaring a pydantic field.
    notifications: Notifications

    @model_validator(mode="after")
    def auto_fail_unit_and_value_set(self) -> Self:
        """Validate that the auto fail unit and value are set if auto fail is True."""
        if self.auto_fail:
            if (
                not self.auto_fail_value
                or not self.auto_fail_unit
                or self.auto_fail_value <= 0
            ):
                raise ValueError(
                    (
                        "Auto fail is enabled. "
                        "Please provide both auto_fail_value and auto_fail_unit."
                    )
                )
        return self

    @model_validator(mode="after")
    def check_num_arms(self) -> Self:
        """Validate the arm count (>= 2; exactly 2 for Bayes A/B)."""
        if len(self.arms) < 2:
            raise ValueError("The experiment must have at least two arms.")
        if self.exp_type == ExperimentsEnum.BAYESAB and len(self.arms) > 2:
            raise ValueError("Bayes AB experiments can only have two arms.")
        return self

    @model_validator(mode="after")
    def check_arm_missing_params(self) -> Self:
        """Validate each arm carries the hyper-parameters its prior requires."""
        prior_type = self.prior_type
        arms = self.arms

        prior_params = {
            ArmPriors.BETA: ("alpha_init", "beta_init"),
            ArmPriors.NORMAL: ("mu_init", "sigma_init"),
        }

        for arm in arms:
            arm_dict = arm.model_dump()
            if prior_type in prior_params:
                missing_params = []
                for param in prior_params[prior_type]:
                    if param not in arm_dict.keys():
                        missing_params.append(param)
                    elif arm_dict[param] is None:
                        missing_params.append(param)

                if missing_params:
                    val = prior_type.value
                    raise ValueError(f"{val} prior needs {','.join(missing_params)}.")
        return self

    @model_validator(mode="after")
    def check_prior_reward_type_combo(self) -> Self:
        """Validate that the prior and reward type combination is allowed."""
        if self.prior_type == ArmPriors.BETA:
            if not self.reward_type == RewardLikelihood.BERNOULLI:
                raise ValueError(
                    "Beta prior can only be used with Bernoulli reward type."
                )

        return self


class ExperimentResponse(ExperimentBase):
    """Pydantic model for a response for experiment creation."""

    experiment_id: int
    n_trials: int
    last_trial_datetime_utc: Optional[str] = None

    arms: list[ArmResponse]
    contexts: Optional[list[ContextResponse]] = None
    clients: Optional[list[Client]] = None
    # NOTE(review): the DB stores one notification row per trigger —
    # a list[NotificationsResponse] may be intended here; confirm with the
    # router that builds this response.
    notifications: NotificationsResponse

    model_config = ConfigDict(from_attributes=True)
- "ExperimentDB.exp_type=='cmab')", + primaryjoin="and_(ExperimentDB.experiment_id==ContextDB.experiment_id," + + "ExperimentDB.exp_type=='cmab')", ) __mapper_args__ = { From c0aa9c48a4611e1a94fc80cf290605d26f30c4e3 Mon Sep 17 00:00:00 2001 From: poornimaramesh Date: Mon, 26 May 2025 18:08:40 +0300 Subject: [PATCH 30/74] Squashed commit of the following: commit b02ed750b265176faef41c3a4c12ed60439a2e2d Author: Aadhil Ahamed Date: Sat May 24 21:53:07 2025 +0530 Fix local development installation (#50) commit c928e08488cbfd09315f7077f3fcc79573ce0fd1 Author: Lakshay <116358226+lakshaydahiya67@users.noreply.github.com> Date: Tue May 13 14:45:08 2025 +0530 Fixes #44 (#46) --- Makefile | 3 +++ backend/Dockerfile | 11 ++++++++++- deployment/docker-compose/template.backend.env | 2 +- 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index b4e7509..e468d4f 100644 --- a/Makefile +++ b/Makefile @@ -27,6 +27,9 @@ fresh-env: pip install psycopg2-binary==2.9.9; \ fi + @echo "Installing frontend dependencies..." 
+ cd frontend && npm install + # --- Local development commands --- setup-dev: setup-redis setup-db diff --git a/backend/Dockerfile b/backend/Dockerfile index bb6e8a1..ed78d43 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -9,7 +9,10 @@ ARG HOME_DIR=/usr/src/${NAME} # Set up the build environment RUN apt-get update && apt-get install -y --no-install-recommends \ - gcc libpq-dev python3-dev dos2unix \ + gcc \ + libpq-dev \ + python3-dev \ + dos2unix \ && rm -rf /var/lib/apt/lists/* # Set up the home directory and permissions @@ -55,6 +58,12 @@ COPY --from=build /tmp/prometheus /tmp/prometheus ENV PYTHONPATH="/usr/local/lib/python3.12/site-packages:${HOME_DIR}" ENV PORT=${PORT} +# Install additional dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + libpq5 \ + cron \ + && rm -rf /var/lib/apt/lists/* + # Set working directory and expose the port WORKDIR ${HOME_DIR} EXPOSE ${PORT} diff --git a/deployment/docker-compose/template.backend.env b/deployment/docker-compose/template.backend.env index 8c2efae..3e7f9de 100644 --- a/deployment/docker-compose/template.backend.env +++ b/deployment/docker-compose/template.backend.env @@ -1,6 +1,6 @@ POSTGRES_USER=postgres POSTGRES_PASSWORD=postgres #pragma: allowlist secret -POSTGRES_HOST=localhost +POSTGRES_HOST=relational_db POSTGRES_PORT=5432 POSTGRES_DB=postgres From b062c7cef2f529de3998f5baf37305a1a5da6dc8 Mon Sep 17 00:00:00 2001 From: poornimaramesh Date: Mon, 26 May 2025 18:44:00 +0300 Subject: [PATCH 31/74] fix tests --- backend/tests/conftest.py | 61 ------------------------- backend/tests/test_auto_fail.py | 36 --------------- backend/tests/test_mabs.py | 21 --------- backend/tests/test_messages.py | 21 --------- backend/tests/test_notifications_job.py | 21 --------- 5 files changed, 160 deletions(-) diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 283eca2..7e6acaf 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -179,67 +179,6 
@@ def regular_user(client: TestClient, db_session: Session) -> Generator: db_session.delete(regular_user) db_session.commit() - # Create user workspace relationship - user_workspace = UserWorkspaceDB( - user_id=regular_user.user_id, - workspace_id=default_workspace.workspace_id, - user_role=UserRoles.ADMIN, - default_workspace=True, - created_datetime_utc=datetime.now(UTC), - updated_datetime_utc=datetime.now(UTC), - ) - - db_session.add(user_workspace) - db_session.commit() - - yield regular_user.user_id, unique_username, unique_api_key - - # Clean up - need to handle foreign key relationships properly - try: - # 1. Clean up pending invitations that reference this user as inviter - db_session.execute( - text( - "DELETE FROM pending_invitations WHERE inviter_id = " - f"{regular_user.user_id}" - ) - ) - db_session.commit() - - # 2. Clean up API key rotation history records that reference this user - db_session.execute( - text( - "DELETE FROM api_key_rotation_history WHERE rotated_by_user_id = " - f"{regular_user.user_id}" - ) - ) - db_session.commit() - - # 3. Remove the user-workspace relationship - db_session.query(UserWorkspaceDB).filter( - UserWorkspaceDB.user_id == regular_user.user_id - ).delete() - db_session.commit() - - # 4. Remove the reference from workspace.api_key_rotated_by_user_id - db_session.query(WorkspaceDB).filter( - WorkspaceDB.api_key_rotated_by_user_id == regular_user.user_id - ).update({WorkspaceDB.api_key_rotated_by_user_id: None}) - db_session.commit() - - # 5. Now delete the workspace - db_session.query(WorkspaceDB).filter( - WorkspaceDB.workspace_name == f"{unique_username}'s Workspace" - ).delete() - db_session.commit() - - # 6. 
Finally delete the user - db_session.delete(regular_user) - db_session.commit() - except Exception as e: - # Log the error but don't fail the test - print(f"Error during cleanup: {e}") - db_session.rollback() - @pytest.fixture(scope="session") def user1(client: TestClient, db_session: Session) -> Generator: diff --git a/backend/tests/test_auto_fail.py b/backend/tests/test_auto_fail.py index 82cf4e5..6a91def 100644 --- a/backend/tests/test_auto_fail.py +++ b/backend/tests/test_auto_fail.py @@ -1,5 +1,4 @@ import copy -import os from datetime import datetime, timedelta, timezone from typing import Generator, Literal, Type @@ -132,41 +131,6 @@ def now(cls, *arg: list) -> datetime: return mydatetime -@fixture -def admin_token(client: TestClient) -> str: - """Get an admin token for authentication""" - response = client.post( - "/login", - data={ - "username": os.environ.get("ADMIN_USERNAME", ""), - "password": os.environ.get("ADMIN_PASSWORD", ""), - }, - ) - token = response.json()["access_token"] - return token - - -@fixture -def workspace_api_key(client: TestClient, admin_token: str) -> str: - """Get the current workspace API key for testing""" - # Get the current workspace - response = client.get( - "/workspace/current", - headers={"Authorization": f"Bearer {admin_token}"}, - ) - assert response.status_code == 200 - - # Rotate the workspace API key to get a fresh one - response = client.put( - "/workspace/rotate-key", - headers={"Authorization": f"Bearer {admin_token}"}, - ) - assert response.status_code == 200 - workspace_api_key = response.json()["new_api_key"] - - return workspace_api_key - - class TestMABAutoFailJob: @fixture def create_mab_with_autofail( diff --git a/backend/tests/test_mabs.py b/backend/tests/test_mabs.py index b7959a4..e65ccb6 100644 --- a/backend/tests/test_mabs.py +++ b/backend/tests/test_mabs.py @@ -72,27 +72,6 @@ def admin_token(client: TestClient) -> str: return token -@fixture -def workspace_api_key(client: TestClient, admin_token: str, 
db_session: Session) -> str: - """Get the current workspace API key for testing""" - # Get the current workspace - response = client.get( - "/workspace/current", - headers={"Authorization": f"Bearer {admin_token}"}, - ) - assert response.status_code == 200 - - # Rotate the workspace API key to get a fresh one - response = client.put( - "/workspace/rotate-key", - headers={"Authorization": f"Bearer {admin_token}"}, - ) - assert response.status_code == 200 - workspace_api_key = response.json()["new_api_key"] - - return workspace_api_key - - @fixture def clean_mabs(db_session: Session) -> Generator: yield diff --git a/backend/tests/test_messages.py b/backend/tests/test_messages.py index 782ba2c..86396b4 100644 --- a/backend/tests/test_messages.py +++ b/backend/tests/test_messages.py @@ -50,27 +50,6 @@ def admin_token(client: TestClient) -> str: return token -@fixture -def workspace_api_key(client: TestClient, admin_token: str) -> str: - """Get the current workspace API key for testing""" - # Get the current workspace - response = client.get( - "/workspace/current", - headers={"Authorization": f"Bearer {admin_token}"}, - ) - assert response.status_code == 200 - - # Rotate the workspace API key to get a fresh one - response = client.put( - "/workspace/rotate-key", - headers={"Authorization": f"Bearer {admin_token}"}, - ) - assert response.status_code == 200 - workspace_api_key = response.json()["new_api_key"] - - return workspace_api_key - - @fixture def experiment_id(client: TestClient, admin_token: str) -> Generator[int, None, None]: response = client.post( diff --git a/backend/tests/test_notifications_job.py b/backend/tests/test_notifications_job.py index 6856920..8b1ef80 100644 --- a/backend/tests/test_notifications_job.py +++ b/backend/tests/test_notifications_job.py @@ -65,27 +65,6 @@ def admin_token(client: TestClient) -> str: return token -@fixture -def workspace_api_key(client: TestClient, admin_token: str) -> str: - """Get the current workspace API key for 
testing""" - # Get the current workspace - response = client.get( - "/workspace/current", - headers={"Authorization": f"Bearer {admin_token}"}, - ) - assert response.status_code == 200 - - # Rotate the workspace API key to get a fresh one - response = client.put( - "/workspace/rotate-key", - headers={"Authorization": f"Bearer {admin_token}"}, - ) - assert response.status_code == 200 - workspace_api_key = response.json()["new_api_key"] - - return workspace_api_key - - class TestNotificationsJob: @fixture def create_mabs_days_elapsed( From 196c3bbafd9dfb0b9ab9cf743c0b6433b380dd34 Mon Sep 17 00:00:00 2001 From: poornimaramesh Date: Tue, 27 May 2025 19:37:18 +0300 Subject: [PATCH 32/74] working version of models, schemas and routers for creating experiments --- backend/app/__init__.py | 7 +- backend/app/contextual_mab/__init__.py | 1 - backend/app/contextual_mab/models.py | 483 ------------------ backend/app/contextual_mab/observation.py | 126 ----- backend/app/contextual_mab/routers.py | 395 -------------- backend/app/contextual_mab/sampling_utils.py | 172 ------- backend/app/contextual_mab/schemas.py | 268 ---------- backend/app/experiments/models.py | 135 +++-- backend/app/experiments/routers.py | 56 ++ backend/app/experiments/schemas.py | 32 +- backend/app/models.py | 12 +- backend/app/workspaces/models.py | 5 +- backend/tests/test_cmabs.py | 452 ---------------- .../docker-compose/docker-compose-dev.yml | 12 +- deployment/docker-compose/docker-compose.yml | 2 +- 15 files changed, 201 insertions(+), 1957 deletions(-) delete mode 100644 backend/app/contextual_mab/__init__.py delete mode 100644 backend/app/contextual_mab/models.py delete mode 100644 backend/app/contextual_mab/observation.py delete mode 100644 backend/app/contextual_mab/routers.py delete mode 100644 backend/app/contextual_mab/sampling_utils.py delete mode 100644 backend/app/contextual_mab/schemas.py create mode 100644 backend/app/experiments/routers.py delete mode 100644 backend/tests/test_cmabs.py 
diff --git a/backend/app/__init__.py b/backend/app/__init__.py index a2f6291..67a6f7c 100644 --- a/backend/app/__init__.py +++ b/backend/app/__init__.py @@ -5,8 +5,9 @@ from fastapi.middleware.cors import CORSMiddleware from redis import asyncio as aioredis -from . import auth, bayes_ab, contextual_mab, mab, messages +from . import auth, bayes_ab, mab, messages from .config import BACKEND_ROOT_PATH, DOMAIN, REDIS_HOST +from .experiments.routers import router as experiments_router from .users.routers import ( router as users_router, ) # to avoid circular imports @@ -56,9 +57,9 @@ def create_app() -> FastAPI: expose_headers=["*"], ) - app = FastAPI(title="Experiments API", lifespan=lifespan) + app.include_router(experiments_router) app.include_router(mab.router) - app.include_router(contextual_mab.router) + # app.include_router(contextual_mab.router) app.include_router(bayes_ab.router) app.include_router(auth.router) app.include_router(users_router) diff --git a/backend/app/contextual_mab/__init__.py b/backend/app/contextual_mab/__init__.py deleted file mode 100644 index fa07d07..0000000 --- a/backend/app/contextual_mab/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .routers import router # noqa: F401 diff --git a/backend/app/contextual_mab/models.py b/backend/app/contextual_mab/models.py deleted file mode 100644 index 60cf723..0000000 --- a/backend/app/contextual_mab/models.py +++ /dev/null @@ -1,483 +0,0 @@ -from datetime import datetime, timezone -from typing import Sequence - -import numpy as np -from sqlalchemy import ( - Float, - ForeignKey, - Integer, - String, - and_, - delete, - select, -) -from sqlalchemy.dialects.postgresql import ARRAY -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import Mapped, mapped_column, relationship - -from ..models import ( - ArmBaseDB, - Base, - DrawsBaseDB, - ExperimentBaseDB, - NotificationsDB, -) -from ..schemas import ObservationType -from .schemas import ContextualBandit - - -class 
ContextualBanditDB(ExperimentBaseDB): - """ - ORM for managing contextual experiments. - """ - - __tablename__ = "contextual_mabs" - - experiment_id: Mapped[int] = mapped_column( - ForeignKey("experiments_base.experiment_id", ondelete="CASCADE"), - primary_key=True, - nullable=False, - ) - - arms: Mapped[list["ContextualArmDB"]] = relationship( - "ContextualArmDB", back_populates="experiment", lazy="joined" - ) - - contexts: Mapped[list["ContextDB"]] = relationship( - "ContextDB", back_populates="experiment", lazy="joined" - ) - - draws: Mapped[list["ContextualDrawDB"]] = relationship( - "ContextualDrawDB", back_populates="experiment", lazy="joined" - ) - - __mapper_args__ = {"polymorphic_identity": "contextual_mabs"} - - def to_dict(self) -> dict: - """ - Convert the ORM object to a dictionary. - """ - return { - "experiment_id": self.experiment_id, - "user_id": self.user_id, - "workspace_id": self.workspace_id, - "name": self.name, - "description": self.description, - "sticky_assignment": self.sticky_assignment, - "auto_fail": self.auto_fail, - "auto_fail_value": self.auto_fail_value, - "auto_fail_unit": self.auto_fail_unit, - "created_datetime_utc": self.created_datetime_utc, - "is_active": self.is_active, - "n_trials": self.n_trials, - "arms": [arm.to_dict() for arm in self.arms], - "contexts": [context.to_dict() for context in self.contexts], - "prior_type": self.prior_type, - "reward_type": self.reward_type, - } - - -class ContextualArmDB(ArmBaseDB): - """ - ORM for managing contextual arms of an experiment - """ - - __tablename__ = "contextual_arms" - - arm_id: Mapped[int] = mapped_column( - ForeignKey("arms_base.arm_id", ondelete="CASCADE"), - primary_key=True, - nullable=False, - ) - - # prior variables for CMAB arms - mu_init: Mapped[float] = mapped_column(Float, nullable=False) - sigma_init: Mapped[float] = mapped_column(Float, nullable=False) - mu: Mapped[list[float]] = mapped_column(ARRAY(Float), nullable=False) - covariance: Mapped[list[float]] = 
mapped_column(ARRAY(Float), nullable=False) - - experiment: Mapped[ContextualBanditDB] = relationship( - "ContextualBanditDB", back_populates="arms", lazy="joined" - ) - draws: Mapped[list["ContextualDrawDB"]] = relationship( - "ContextualDrawDB", back_populates="arm", lazy="joined" - ) - - __mapper_args__ = {"polymorphic_identity": "contextual_arms"} - - def to_dict(self) -> dict: - """ - Convert the ORM object to a dictionary. - """ - return { - "arm_id": self.arm_id, - "name": self.name, - "description": self.description, - "mu_init": self.mu_init, - "sigma_init": self.sigma_init, - "mu": self.mu, - "covariance": self.covariance, - "draws": [draw.to_dict() for draw in self.draws], - } - - -class ContextDB(Base): - """ - ORM for managing context for an experiment - """ - - __tablename__ = "contexts" - - context_id: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False) - experiment_id: Mapped[int] = mapped_column( - Integer, ForeignKey("contextual_mabs.experiment_id"), nullable=False - ) - user_id: Mapped[int] = mapped_column( - Integer, ForeignKey("users.user_id"), nullable=False - ) - name: Mapped[str] = mapped_column(String(length=150), nullable=False) - description: Mapped[str] = mapped_column(String(length=500), nullable=True) - value_type: Mapped[str] = mapped_column(String(length=50), nullable=False) - - experiment: Mapped[ContextualBanditDB] = relationship( - "ContextualBanditDB", back_populates="contexts", lazy="joined" - ) - - def to_dict(self) -> dict: - """ - Convert the ORM object to a dictionary. 
- """ - return { - "context_id": self.context_id, - "name": self.name, - "description": self.description, - "value_type": self.value_type, - } - - -class ContextualDrawDB(DrawsBaseDB): - """ - ORM for managing draws of an experiment - """ - - __tablename__ = "contextual_draws" - - draw_id: Mapped[str] = mapped_column( - ForeignKey("draws_base.draw_id", ondelete="CASCADE"), - primary_key=True, - nullable=False, - ) - - context_val: Mapped[list] = mapped_column(ARRAY(Float), nullable=False) - arm: Mapped[ContextualArmDB] = relationship( - "ContextualArmDB", back_populates="draws", lazy="joined" - ) - experiment: Mapped[ContextualBanditDB] = relationship( - "ContextualBanditDB", back_populates="draws", lazy="joined" - ) - - __mapper_args__ = {"polymorphic_identity": "contextual_draws"} - - def to_dict(self) -> dict: - """ - Convert the ORM object to a dictionary. - """ - return { - "draw_id": self.draw_id, - "client_id": self.client_id, - "draw_datetime_utc": self.draw_datetime_utc, - "context_val": self.context_val, - "arm_id": self.arm_id, - "experiment_id": self.experiment_id, - "user_id": self.user_id, - "reward": self.reward, - "observation_type": self.observation_type, - "observed_datetime_utc": self.observed_datetime_utc, - } - - -async def save_contextual_mab_to_db( - experiment: ContextualBandit, - user_id: int, - workspace_id: int, - asession: AsyncSession, -) -> ContextualBanditDB: - """ - Save the experiment to the database. 
- """ - contexts = [ - ContextDB( - name=context.name, - description=context.description, - value_type=context.value_type.value, - user_id=user_id, - ) - for context in experiment.contexts - ] - arms = [] - for arm in experiment.arms: - arms.append( - ContextualArmDB( - name=arm.name, - description=arm.description, - mu_init=arm.mu_init, - sigma_init=arm.sigma_init, - mu=(np.ones(len(experiment.contexts)) * arm.mu_init).tolist(), - covariance=( - np.identity(len(experiment.contexts)) * arm.sigma_init - ).tolist(), - user_id=user_id, - n_outcomes=arm.n_outcomes, - ) - ) - - experiment_db = ContextualBanditDB( - name=experiment.name, - description=experiment.description, - user_id=user_id, - workspace_id=workspace_id, - is_active=experiment.is_active, - created_datetime_utc=datetime.now(timezone.utc), - n_trials=0, - arms=arms, - sticky_assignment=experiment.sticky_assignment, - auto_fail=experiment.auto_fail, - auto_fail_value=experiment.auto_fail_value, - auto_fail_unit=experiment.auto_fail_unit, - contexts=contexts, - prior_type=experiment.prior_type.value, - reward_type=experiment.reward_type.value, - ) - - asession.add(experiment_db) - await asession.commit() - await asession.refresh(experiment_db) - - return experiment_db - - -async def get_all_contextual_mabs( - workspace_id: int, - asession: AsyncSession, -) -> Sequence[ContextualBanditDB]: - """ - Get all the contextual experiments from the database for a specific workspace. - """ - statement = ( - select(ContextualBanditDB) - .where(ContextualBanditDB.workspace_id == workspace_id) - .order_by(ContextualBanditDB.experiment_id) - ) - - return (await asession.execute(statement)).unique().scalars().all() - - -async def get_contextual_mab_by_id( - experiment_id: int, workspace_id: int, asession: AsyncSession -) -> ContextualBanditDB | None: - """ - Get the contextual experiment by id from a specific workspace. 
- """ - condition = [ - ContextualBanditDB.experiment_id == experiment_id, - ContextualBanditDB.workspace_id == workspace_id, - ] - - statement = select(ContextualBanditDB).where(*condition) - result = await asession.execute(statement) - - return result.unique().scalar_one_or_none() - - -async def delete_contextual_mab_by_id( - experiment_id: int, workspace_id: int, asession: AsyncSession -) -> None: - """ - Delete the contextual experiment by id. - """ - await asession.execute( - delete(NotificationsDB).where(NotificationsDB.experiment_id == experiment_id) - ) - - await asession.execute( - delete(ContextualDrawDB).where(ContextualDrawDB.experiment_id == experiment_id) - ) - - await asession.execute( - delete(ContextDB).where(ContextDB.experiment_id == experiment_id) - ) - - await asession.execute( - delete(ContextualArmDB).where(ContextualArmDB.experiment_id == experiment_id) - ) - - await asession.execute( - delete(ContextualBanditDB).where( - and_( - ContextualBanditDB.workspace_id == workspace_id, - ContextualBanditDB.experiment_id == experiment_id, - ContextualBanditDB.experiment_id == ExperimentBaseDB.experiment_id, - ) - ) - ) - await asession.commit() - return None - - -async def save_contextual_obs_to_db( - draw: ContextualDrawDB, - reward: float, - asession: AsyncSession, - observation_type: ObservationType, -) -> ContextualDrawDB: - """ - Save the observation to the database. 
- """ - draw.reward = reward - draw.observed_datetime_utc = datetime.now(timezone.utc) - draw.observation_type = observation_type # Remove .value, pass enum directly - - await asession.commit() - await asession.refresh(draw) - - return draw - - -async def get_contextual_obs_by_experiment_arm_id( - experiment_id: int, - arm_id: int, - asession: AsyncSession, -) -> Sequence[ContextualDrawDB]: - """Get the observations for a specific arm of an experiment.""" - statement = ( - select(ContextualDrawDB) - .where( - and_( - ContextualDrawDB.experiment_id == experiment_id, - ContextualDrawDB.arm_id == arm_id, - ContextualDrawDB.reward.is_not(None), - ) - ) - .order_by(ContextualDrawDB.observed_datetime_utc) - ) - - result = await asession.execute(statement) - return result.unique().scalars().all() - - -async def get_all_contextual_obs_by_experiment_id( - experiment_id: int, - workspace_id: int, - asession: AsyncSession, -) -> Sequence[ContextualDrawDB]: - """ - Get all observations for an experiment, - verified to belong to the specified workspace. - """ - # First, verify experiment belongs to the workspace - experiment = await get_contextual_mab_by_id( - experiment_id=experiment_id, - workspace_id=workspace_id, - asession=asession, - ) - - if experiment is None: - # Return empty list if experiment doesn't exist or doesn't belong to workspace - return [] - - # Get all observations for this experiment - statement = ( - select(ContextualDrawDB) - .where( - and_( - ContextualDrawDB.experiment_id == experiment_id, - ContextualDrawDB.reward.is_not(None), - ) - ) - .order_by(ContextualDrawDB.observed_datetime_utc) - ) - - result = await asession.execute(statement) - return result.unique().scalars().all() - - -async def get_draw_by_id( - draw_id: str, asession: AsyncSession -) -> ContextualDrawDB | None: - """ - Get the draw by its ID, which should be unique across the system. 
- """ - statement = select(ContextualDrawDB).where(ContextualDrawDB.draw_id == draw_id) - result = await asession.execute(statement) - return result.unique().scalar_one_or_none() - - -async def get_draw_by_client_id( - client_id: str, experiment_id: int, asession: AsyncSession -) -> ContextualDrawDB | None: - """ - Get the draw by client id for a specific experiment. - """ - statement = ( - select(ContextualDrawDB) - .where(ContextualDrawDB.client_id == client_id) - .where(ContextualDrawDB.client_id.is_not(None)) - .where(ContextualDrawDB.experiment_id == experiment_id) - ) - result = await asession.execute(statement) - - return result.unique().scalars().first() - - -async def save_draw_to_db( - experiment_id: int, - arm_id: int, - context_val: list[float], - draw_id: str, - client_id: str | None, - user_id: int | None, - asession: AsyncSession, - workspace_id: int | None, -) -> ContextualDrawDB: - """ - Save the draw to the database. - """ - # If user_id is not provided but needed, get it from the experiment - if user_id is None: - if workspace_id is not None: - # Try to get experiment with workspace_id - experiment = await get_contextual_mab_by_id( - experiment_id=experiment_id, - workspace_id=workspace_id, - asession=asession, - ) - if experiment: - user_id = experiment.user_id - else: - raise ValueError(f"Experiment with id {experiment_id} not found") - else: - # Fall back to direct get if workspace_id not provided - experiment = await asession.get(ContextualBanditDB, experiment_id) - if experiment: - user_id = experiment.user_id - else: - raise ValueError(f"Experiment with id {experiment_id} not found") - - if user_id is None: - raise ValueError("User ID must be provided or derivable from experiment") - - draw_db = ContextualDrawDB( - draw_id=draw_id, - client_id=client_id, - arm_id=arm_id, - experiment_id=experiment_id, - user_id=user_id, - context_val=context_val, - draw_datetime_utc=datetime.now(timezone.utc), - ) - - asession.add(draw_db) - await 
asession.commit() - await asession.refresh(draw_db) - - return draw_db diff --git a/backend/app/contextual_mab/observation.py b/backend/app/contextual_mab/observation.py deleted file mode 100644 index e655bbf..0000000 --- a/backend/app/contextual_mab/observation.py +++ /dev/null @@ -1,126 +0,0 @@ -from datetime import datetime, timezone -from typing import Sequence - -import numpy as np -from fastapi.exceptions import HTTPException -from sqlalchemy.ext.asyncio import AsyncSession - -from ..schemas import ( - ObservationType, - RewardLikelihood, -) -from .models import ( - ContextualArmDB, - ContextualBanditDB, - ContextualDrawDB, - get_contextual_obs_by_experiment_arm_id, - save_contextual_obs_to_db, -) -from .sampling_utils import update_arm_params -from .schemas import ( - ContextualArmResponse, - ContextualBanditSample, -) - - -async def update_based_on_outcome( - experiment: ContextualBanditDB, - draw: ContextualDrawDB, - reward: float, - asession: AsyncSession, - observation_type: ObservationType, -) -> ContextualArmResponse: - """ - Update the arm based on the outcome of the draw. - - This is a helper function to allow `auto_fail` job to call - it as well. 
- """ - - update_experiment_metadata(experiment) - - arm = get_arm_from_experiment(experiment, draw.arm_id) - arm.n_outcomes += 1 - - # Ensure reward is binary for Bernoulli reward type - if experiment.reward_type == RewardLikelihood.BERNOULLI.value: - if reward not in [0, 1]: - raise HTTPException( - status_code=400, - detail="Reward must be 0 or 1 for Bernoulli reward type.", - ) - - # Get data for arm update - all_obs, contexts, rewards = await prepare_data_for_arm_update( - experiment.experiment_id, arm.arm_id, asession, draw, reward - ) - - experiment_data = ContextualBanditSample.model_validate(experiment) - mu, covariance = update_arm_params( - arm=ContextualArmResponse.model_validate(arm), - prior_type=experiment_data.prior_type, - reward_type=experiment_data.reward_type, - context=contexts, - reward=rewards, - ) - - await save_updated_data( - arm, mu, covariance, draw, reward, observation_type, asession - ) - - return ContextualArmResponse.model_validate(arm) - - -def update_experiment_metadata(experiment: ContextualBanditDB) -> None: - """Update experiment metadata with new trial information""" - experiment.n_trials += 1 - experiment.last_trial_datetime_utc = datetime.now(tz=timezone.utc) - - -def get_arm_from_experiment( - experiment: ContextualBanditDB, arm_id: int -) -> ContextualArmDB: - """Get and validate the arm from the experiment""" - arms = [a for a in experiment.arms if a.arm_id == arm_id] - if not arms: - raise HTTPException(status_code=404, detail=f"Arm with id {arm_id} not found") - return arms[0] - - -async def prepare_data_for_arm_update( - experiment_id: int, - arm_id: int, - asession: AsyncSession, - draw: ContextualDrawDB, - reward: float, -) -> tuple[Sequence[ContextualDrawDB], list[list], list[float]]: - """Prepare the data needed for updating arm parameters""" - all_obs = await get_contextual_obs_by_experiment_arm_id( - experiment_id=experiment_id, - arm_id=arm_id, - asession=asession, - ) - - rewards = [obs.reward for obs in 
all_obs] + [reward] - contexts = [obs.context_val for obs in all_obs] - contexts.append(draw.context_val) - - return all_obs, contexts, rewards - - -async def save_updated_data( - arm: ContextualArmDB, - mu: np.ndarray, - covariance: np.ndarray, - draw: ContextualDrawDB, - reward: float, - observation_type: ObservationType, - asession: AsyncSession, -) -> None: - """Save the updated arm and observation data""" - arm.mu = mu.tolist() - arm.covariance = covariance.tolist() - asession.add(arm) - await asession.commit() - - await save_contextual_obs_to_db(draw, reward, asession, observation_type) diff --git a/backend/app/contextual_mab/routers.py b/backend/app/contextual_mab/routers.py deleted file mode 100644 index 08eea28..0000000 --- a/backend/app/contextual_mab/routers.py +++ /dev/null @@ -1,395 +0,0 @@ -from typing import Annotated, List, Optional -from uuid import uuid4 - -from fastapi import APIRouter, Depends -from fastapi.exceptions import HTTPException -from sqlalchemy.ext.asyncio import AsyncSession - -from ..auth.dependencies import ( - authenticate_workspace_key, - get_verified_user, - require_admin_role, -) -from ..database import get_async_session -from ..models import get_notifications_from_db, save_notifications_to_db -from ..schemas import ( - ContextType, - NotificationsResponse, - ObservationType, - Outcome, -) -from ..users.models import UserDB -from ..utils import setup_logger -from ..workspaces.models import ( - WorkspaceDB, - get_user_default_workspace, -) -from .models import ( - ContextualBanditDB, - ContextualDrawDB, - delete_contextual_mab_by_id, - get_all_contextual_mabs, - get_all_contextual_obs_by_experiment_id, - get_contextual_mab_by_id, - get_draw_by_client_id, - get_draw_by_id, - save_contextual_mab_to_db, - save_draw_to_db, -) -from .observation import update_based_on_outcome -from .sampling_utils import choose_arm -from .schemas import ( - CMABDrawResponse, - CMABObservationResponse, - ContextInput, - ContextualArmResponse, - 
ContextualBandit, - ContextualBanditResponse, - ContextualBanditSample, -) - -router = APIRouter(prefix="/contextual_mab", tags=["Contextual Bandits"]) - -logger = setup_logger(__name__) - - -@router.post("/", response_model=ContextualBanditResponse) -async def create_contextual_mabs( - experiment: ContextualBandit, - user_db: Annotated[UserDB, Depends(require_admin_role)], - asession: AsyncSession = Depends(get_async_session), -) -> ContextualBanditResponse | HTTPException: - """ - Create a new contextual experiment with different priors for each context. - """ - workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) - - if workspace_db is None: - raise HTTPException( - status_code=404, - detail="Workspace not found. Please create a workspace first.", - ) - - cmab = await save_contextual_mab_to_db( - experiment, user_db.user_id, workspace_db.workspace_id, asession - ) - notifications = await save_notifications_to_db( - experiment_id=cmab.experiment_id, - user_id=user_db.user_id, - notifications=experiment.notifications, - asession=asession, - ) - cmab_dict = cmab.to_dict() - cmab_dict["notifications"] = [n.to_dict() for n in notifications] - return ContextualBanditResponse.model_validate(cmab_dict) - - -@router.get("/", response_model=list[ContextualBanditResponse]) -async def get_contextual_mabs( - user_db: Annotated[UserDB, Depends(get_verified_user)], - asession: AsyncSession = Depends(get_async_session), -) -> list[ContextualBanditResponse]: - """ - Get details of all experiments. - """ - workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) - - if workspace_db is None: - raise HTTPException( - status_code=404, - detail="Workspace not found. 
Please create a workspace first.", - ) - - experiments = await get_all_contextual_mabs(workspace_db.workspace_id, asession) - all_experiments = [] - for exp in experiments: - exp_dict = exp.to_dict() - exp_dict["notifications"] = [ - n.to_dict() - for n in await get_notifications_from_db( - exp.experiment_id, exp.user_id, asession - ) - ] - all_experiments.append( - ContextualBanditResponse.model_validate( - { - **exp_dict, - "notifications": [ - NotificationsResponse.model_validate(n) - for n in exp_dict["notifications"] - ], - } - ) - ) - - return all_experiments - - -@router.get("/{experiment_id}", response_model=ContextualBanditResponse) -async def get_contextual_mab( - experiment_id: int, - user_db: Annotated[UserDB, Depends(get_verified_user)], - asession: AsyncSession = Depends(get_async_session), -) -> ContextualBanditResponse | HTTPException: - """ - Get details of experiment with the provided `experiment_id`. - """ - workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) - - if workspace_db is None: - raise HTTPException( - status_code=404, - detail="Workspace not found. Please create a workspace first.", - ) - - experiment = await get_contextual_mab_by_id( - experiment_id, workspace_db.workspace_id, asession - ) - if experiment is None: - raise HTTPException( - status_code=404, detail=f"Experiment with id {experiment_id} not found" - ) - - experiment_dict = experiment.to_dict() - experiment_dict["notifications"] = [ - n.to_dict() - for n in await get_notifications_from_db( - experiment.experiment_id, experiment.user_id, asession - ) - ] - return ContextualBanditResponse.model_validate(experiment_dict) - - -@router.delete("/{experiment_id}", response_model=dict) -async def delete_contextual_mab( - experiment_id: int, - user_db: Annotated[UserDB, Depends(require_admin_role)], - asession: AsyncSession = Depends(get_async_session), -) -> dict: - """ - Delete the experiment with the provided `experiment_id`. 
- """ - try: - workspace_db = await get_user_default_workspace( - asession=asession, user_db=user_db - ) - - if workspace_db is None: - raise HTTPException( - status_code=404, - detail="Workspace not found. Please create a workspace first.", - ) - - experiment = await get_contextual_mab_by_id( - experiment_id, workspace_db.workspace_id, asession - ) - if experiment is None: - raise HTTPException( - status_code=404, detail=f"Experiment with id {experiment_id} not found" - ) - await delete_contextual_mab_by_id( - experiment_id, workspace_db.workspace_id, asession - ) - return {"detail": f"Experiment {experiment_id} deleted successfully."} - except Exception as e: - raise HTTPException(status_code=500, detail=f"Error: {e}") from e - - -@router.post("/{experiment_id}/draw", response_model=CMABDrawResponse) -async def draw_arm( - experiment_id: int, - context: List[ContextInput], - draw_id: Optional[str] = None, - client_id: Optional[str] = None, - workspace_db: WorkspaceDB = Depends(authenticate_workspace_key), - asession: AsyncSession = Depends(get_async_session), -) -> CMABDrawResponse: - """ - Get which arm to pull next for provided experiment. 
- """ - # Get workspace from user context - workspace_id = workspace_db.workspace_id - - experiment = await get_contextual_mab_by_id(experiment_id, workspace_id, asession) - - if experiment is None: - raise HTTPException( - status_code=404, detail=f"Experiment with id {experiment_id} not found" - ) - - # Check context inputs - if len(experiment.contexts) != len(context): - raise HTTPException( - status_code=400, - detail="Number of contexts provided does not match the num contexts.", - ) - experiment_data = ContextualBanditSample.model_validate(experiment) - sorted_context = list(sorted(context, key=lambda x: x.context_id)) - - try: - for c_input, c_exp in zip( - sorted_context, - sorted(experiment.contexts, key=lambda x: x.context_id), - ): - if c_exp.value_type == ContextType.BINARY.value: - Outcome(c_input.context_value) - except ValueError as e: - raise HTTPException( - status_code=400, - detail=f"Invalid context value: {e}", - ) from e - - # Generate UUID if not provided - if draw_id is None: - draw_id = str(uuid4()) - - existing_draw = await get_draw_by_id(draw_id, asession) - if existing_draw: - raise HTTPException( - status_code=400, - detail=f"Draw ID {draw_id} already exists.", - ) - - # Check if sticky assignment - if experiment.sticky_assignment and not client_id: - raise HTTPException( - status_code=400, - detail="Client ID is required for sticky assignment.", - ) - - chosen_arm = choose_arm( - experiment_data, - [c.context_value for c in sorted_context], - ) - chosen_arm_id = experiment.arms[chosen_arm].arm_id - if experiment.sticky_assignment and client_id: - previous_draw = await get_draw_by_client_id( - client_id=client_id, - experiment_id=experiment.experiment_id, - asession=asession, - ) - if previous_draw: - chosen_arm_id = previous_draw.arm_id - - try: - _ = await save_draw_to_db( - experiment_id=experiment.experiment_id, - arm_id=chosen_arm_id, - context_val=[c.context_value for c in sorted_context], - draw_id=draw_id, - client_id=client_id, - 
user_id=None, - asession=asession, - workspace_id=workspace_id, - ) - except Exception as e: - raise HTTPException( - status_code=500, - detail=f"Error saving draw to database: {e}", - ) from e - - return CMABDrawResponse.model_validate( - { - "draw_id": draw_id, - "client_id": client_id, - "arm": ContextualArmResponse.model_validate( - [arm for arm in experiment.arms if arm.arm_id == chosen_arm_id][0] - ), - } - ) - - -@router.put("/{experiment_id}/{draw_id}/{reward}", response_model=ContextualArmResponse) -async def update_arm( - experiment_id: int, - draw_id: str, - reward: float, - workspace_db: WorkspaceDB = Depends(authenticate_workspace_key), - asession: AsyncSession = Depends(get_async_session), -) -> ContextualArmResponse: - """ - Update the arm with the provided `arm_id` for the given - `experiment_id` based on the reward. - """ - # Get workspace from user context - workspace_id = workspace_db.workspace_id - - # Get the experiment and do checks - experiment, draw = await validate_experiment_and_draw( - experiment_id, draw_id, workspace_id, asession - ) - - return await update_based_on_outcome( - experiment, draw, reward, asession, ObservationType.USER - ) - - -@router.get( - "/{experiment_id}/outcomes", - response_model=list[CMABObservationResponse], -) -async def get_outcomes( - experiment_id: int, - workspace_db: WorkspaceDB = Depends(authenticate_workspace_key), - asession: AsyncSession = Depends(get_async_session), -) -> list[CMABObservationResponse]: - """ - Get the outcomes for the experiment. 
- """ - # Get workspace from user context - workspace_id = workspace_db.workspace_id - - experiment = await get_contextual_mab_by_id(experiment_id, workspace_id, asession) - if not experiment: - raise HTTPException( - status_code=404, detail=f"Experiment with id {experiment_id} not found" - ) - - observations = await get_all_contextual_obs_by_experiment_id( - experiment_id=experiment.experiment_id, - workspace_id=workspace_id, - asession=asession, - ) - return [CMABObservationResponse.model_validate(obs) for obs in observations] - - -async def validate_experiment_and_draw( - experiment_id: int, - draw_id: str, - workspace_id: int, - asession: AsyncSession, -) -> tuple[ContextualBanditDB, ContextualDrawDB]: - """ - Validate that the experiment exists in the workspace - and the draw exists for that experiment. - """ - experiment = await get_contextual_mab_by_id( - experiment_id=experiment_id, - workspace_id=workspace_id, - asession=asession, - ) - if experiment is None: - raise HTTPException( - status_code=404, detail=f"Experiment with id {experiment_id} not found" - ) - - draw = await get_draw_by_id(draw_id=draw_id, asession=asession) - if draw is None: - raise HTTPException(status_code=404, detail=f"Draw with id {draw_id} not found") - - if draw.experiment_id != experiment_id: - raise HTTPException( - status_code=400, - detail=( - f"Draw with id {draw_id} does not belong " - f"to experiment with id {experiment_id}", - ), - ) - - if draw.reward is not None: - raise HTTPException( - status_code=400, - detail=f"Draw with id {draw_id} already has a reward.", - ) - - return experiment, draw diff --git a/backend/app/contextual_mab/sampling_utils.py b/backend/app/contextual_mab/sampling_utils.py deleted file mode 100644 index 03c9784..0000000 --- a/backend/app/contextual_mab/sampling_utils.py +++ /dev/null @@ -1,172 +0,0 @@ -import numpy as np -from scipy.optimize import minimize - -from ..schemas import ArmPriors, ContextLinkFunctions, RewardLikelihood -from .schemas 
import ContextualArmResponse, ContextualBanditSample - - -def sample_normal( - mus: list[np.ndarray], - covariances: list[np.ndarray], - context: np.ndarray, - link_function: ContextLinkFunctions, -) -> int: - """ - Thompson Sampling with normal prior. - - Parameters - ---------- - mus: mean of Normal distribution for each arm - covariances: covariance matrix of Normal distribution for each arm - context: context vector - link_function: link function for the context - """ - samples = np.array( - [ - np.random.multivariate_normal(mean=mu, cov=cov) - for mu, cov in zip(mus, covariances) - ] - ).reshape(-1, len(context)) - probs = link_function(samples @ context) - return int(probs.argmax()) - - -def update_arm_normal( - current_mu: np.ndarray, - current_covariance: np.ndarray, - reward: float, - context: np.ndarray, - sigma_llhood: float, -) -> tuple[np.ndarray, np.ndarray]: - """ - Update the mean and covariance of the normal distribution. - - Parameters - ---------- - current_mu : The mean of the normal distribution. - current_covariance : The covariance matrix of the normal distribution. - reward : The reward of the arm. - context : The context vector. - sigma_llhood : The stddev of the likelihood. - """ - new_covariance_inv = ( - np.linalg.inv(current_covariance) + (context.T @ context) / sigma_llhood**2 - ) - new_covariance = np.linalg.inv(new_covariance_inv) - - new_mu = new_covariance @ ( - np.linalg.inv(current_covariance) @ current_mu - + context * reward / sigma_llhood**2 - ) - return new_mu, new_covariance - - -def update_arm_laplace( - current_mu: np.ndarray, - current_covariance: np.ndarray, - reward: np.ndarray, - context: np.ndarray, - link_function: ContextLinkFunctions, - reward_likelihood: RewardLikelihood, - prior_type: ArmPriors, -) -> tuple[np.ndarray, np.ndarray]: - """ - Update the mean and covariance using the Laplace approximation. - - Parameters - ---------- - current_mu : The mean of the normal distribution. 
- current_covariance : The covariance matrix of the normal distribution. - reward : The list of rewards for the arm. - context : The list of contexts for the arm. - link_function : The link function for parameters to rewards. - reward_likelihood : The likelihood function of the reward. - prior_type : The prior type of the arm. - """ - - def objective(theta: np.ndarray) -> float: - """ - Objective function for the Laplace approximation. - - Parameters - ---------- - theta : The parameters of the arm. - """ - # Log prior - log_prior = prior_type(theta, mu=current_mu, covariance=current_covariance) - - # Log likelihood - log_likelihood = reward_likelihood(reward, link_function(context @ theta)) - - return -log_prior - log_likelihood - - result = minimize(objective, current_mu, method="L-BFGS-B", hess="2-point") - new_mu = result.x - covariance = result.hess_inv.todense() # type: ignore - - new_covariance = 0.5 * (covariance + covariance.T) - return new_mu, new_covariance.astype(np.float64) - - -def choose_arm(experiment: ContextualBanditSample, context: list[float]) -> int: - """ - Choose the arm with the highest probability. - - Parameters - ---------- - experiment : The experiment object. - context : The context vector. - """ - link_function = ( - ContextLinkFunctions.NONE - if experiment.reward_type == RewardLikelihood.NORMAL - else ContextLinkFunctions.LOGISTIC - ) - return sample_normal( - mus=[np.array(arm.mu) for arm in experiment.arms], - covariances=[np.array(arm.covariance) for arm in experiment.arms], - context=np.array(context), - link_function=link_function, - ) - - -def update_arm_params( - arm: ContextualArmResponse, - prior_type: ArmPriors, - reward_type: RewardLikelihood, - reward: list, - context: list, -) -> tuple[np.ndarray, np.ndarray]: - """ - Update the arm parameters. - - Parameters - ---------- - arm : The arm object. - prior_type : The prior type of the arm. - reward_type : The reward type of the arm. - reward : All rewards for the arm. 
- context : All context vectors for the arm. - """ - if (prior_type == ArmPriors.NORMAL) and (reward_type == RewardLikelihood.NORMAL): - return update_arm_normal( - current_mu=np.array(arm.mu), - current_covariance=np.array(arm.covariance), - reward=reward[-1], - context=np.array(context[-1]), - sigma_llhood=1.0, # TODO: need to implement likelihood stddev - ) - elif (prior_type == ArmPriors.NORMAL) and ( - reward_type == RewardLikelihood.BERNOULLI - ): - return update_arm_laplace( - current_mu=np.array(arm.mu), - current_covariance=np.array(arm.covariance), - reward=np.array(reward), - context=np.array(context), - link_function=ContextLinkFunctions.LOGISTIC, - reward_likelihood=RewardLikelihood.BERNOULLI, - prior_type=ArmPriors.NORMAL, - ) - else: - raise ValueError("Prior and reward type combination is not supported.") diff --git a/backend/app/contextual_mab/schemas.py b/backend/app/contextual_mab/schemas.py deleted file mode 100644 index 57baaf4..0000000 --- a/backend/app/contextual_mab/schemas.py +++ /dev/null @@ -1,268 +0,0 @@ -from datetime import datetime -from typing import Optional, Self - -from pydantic import BaseModel, ConfigDict, Field, model_validator - -from ..schemas import ( - ArmPriors, - AutoFailUnitType, - ContextType, - Notifications, - NotificationsResponse, - RewardLikelihood, - allowed_combos_cmab, -) - - -class Context(BaseModel): - """ - Pydantic model for a binary-valued context of the experiment. 
- """ - - name: str = Field( - description="Name of the context", - examples=["Context 1"], - ) - description: str = Field( - description="Description of the context", - examples=["This is a description of the context."], - ) - value_type: ContextType = Field( - description="Type of value the context can take", default=ContextType.BINARY - ) - model_config = ConfigDict(from_attributes=True) - - -class ContextResponse(Context): - """ - Pydantic model for an response for context creation - """ - - context_id: int - model_config = ConfigDict(from_attributes=True) - - -class ContextInput(BaseModel): - """ - Pydantic model for a context input - """ - - context_id: int - context_value: float - model_config = ConfigDict(from_attributes=True) - - -class ContextualArm(BaseModel): - """ - Pydantic model for a contextual arm of the experiment. - """ - - name: str = Field( - max_length=150, - examples=["Arm 1"], - ) - description: str = Field( - max_length=500, - examples=["This is a description of the arm."], - ) - - mu_init: float = Field( - default=0.0, - examples=[0.0, 1.2, 5.7], - description="Mean parameter for Normal prior", - ) - - sigma_init: float = Field( - default=1.0, - examples=[1.0, 0.5, 2.0], - description="Standard deviation parameter for Normal prior", - ) - n_outcomes: Optional[int] = Field( - default=0, - description="Number of outcomes for the arm", - examples=[0, 10, 15], - ) - - @model_validator(mode="after") - def check_values(self) -> Self: - """ - Check if the values are unique and set new attributes. 
- """ - sigma = self.sigma_init - if sigma is not None and sigma <= 0: - raise ValueError("Std dev must be greater than 0.") - return self - - model_config = ConfigDict(from_attributes=True) - - -class ContextualArmResponse(ContextualArm): - """ - Pydantic model for an response for contextual arm creation - """ - - arm_id: int - mu: list[float] - covariance: list[list[float]] - - model_config = ConfigDict(from_attributes=True) - - -class ContextualBanditBase(BaseModel): - """ - Pydantic model for a contextual experiment - Base model. - Note: Do not use this model directly. Use ContextualBandit instead. - """ - - name: str = Field( - max_length=150, - examples=["Experiment 1"], - ) - - description: str = Field( - max_length=500, - examples=["This is a description of the experiment."], - ) - - sticky_assignment: bool = Field( - description="Whether the arm assignment is sticky or not.", - default=False, - ) - - auto_fail: bool = Field( - description=( - "Whether the experiment should fail automatically after " - "a certain period if no outcome is registered." - ), - default=False, - ) - - auto_fail_value: Optional[int] = Field( - description="The time period after which the experiment should fail.", - default=None, - ) - - auto_fail_unit: Optional[AutoFailUnitType] = Field( - description="The time unit for the auto fail period.", - default=None, - ) - - reward_type: RewardLikelihood = Field( - description="The type of reward we observe from the experiment.", - default=RewardLikelihood.BERNOULLI, - ) - - prior_type: ArmPriors = Field( - description="The type of prior distribution for the arms.", - default=ArmPriors.NORMAL, - ) - - is_active: bool = True - - model_config = ConfigDict(from_attributes=True) - - -class ContextualBandit(ContextualBanditBase): - """ - Pydantic model for a contextual experiment. 
- """ - - arms: list[ContextualArm] - contexts: list[Context] - notifications: Notifications - - @model_validator(mode="after") - def auto_fail_unit_and_value_set(self) -> Self: - """ - Validate that the auto fail unit and value are set if auto fail is True. - """ - if self.auto_fail: - if ( - not self.auto_fail_value - or not self.auto_fail_unit - or self.auto_fail_value <= 0 - ): - raise ValueError( - ( - "Auto fail is enabled. " - "Please provide both auto_fail_value and auto_fail_unit." - ) - ) - return self - - @model_validator(mode="after") - def arms_at_least_two(self) -> Self: - """ - Validate that the experiment has at least two arms. - """ - if len(self.arms) < 2: - raise ValueError("The experiment must have at least two arms.") - return self - - @model_validator(mode="after") - def check_prior_reward_type_combo(self) -> Self: - """ - Validate that the prior and reward type combination is allowed. - """ - - if (self.prior_type, self.reward_type) not in allowed_combos_cmab: - raise ValueError("Prior and reward type combo not supported.") - return self - - model_config = ConfigDict(from_attributes=True) - - -class ContextualBanditResponse(ContextualBanditBase): - """ - Pydantic model for an response for contextual experiment creation. - Returns the id of the experiment, the arms and the contexts - """ - - experiment_id: int - workspace_id: int - arms: list[ContextualArmResponse] - contexts: list[ContextResponse] - notifications: list[NotificationsResponse] - created_datetime_utc: datetime - last_trial_datetime_utc: Optional[datetime] = None - n_trials: int - - model_config = ConfigDict(from_attributes=True) - - -class ContextualBanditSample(ContextualBanditBase): - """ - Pydantic model for a contextual experiment sample. 
- """ - - experiment_id: int - arms: list[ContextualArmResponse] - contexts: list[ContextResponse] - - -class CMABObservationResponse(BaseModel): - """ - Pydantic model for an response for contextual observation creation - """ - - arm_id: int - reward: float - context_val: list[float] - - draw_id: str - client_id: str | None - observed_datetime_utc: datetime - - model_config = ConfigDict(from_attributes=True) - - -class CMABDrawResponse(BaseModel): - """ - Pydantic model for an response for contextual arm draw - """ - - draw_id: str - client_id: str | None - arm: ContextualArmResponse - - model_config = ConfigDict(from_attributes=True) diff --git a/backend/app/experiments/models.py b/backend/app/experiments/models.py index 9176c3f..b7df1c4 100644 --- a/backend/app/experiments/models.py +++ b/backend/app/experiments/models.py @@ -1,6 +1,6 @@ import uuid -from datetime import datetime -from typing import TYPE_CHECKING, Optional, Sequence +from datetime import datetime, timezone +from typing import Optional, Sequence from sqlalchemy import ( Boolean, @@ -14,25 +14,17 @@ ) from sqlalchemy.dialects.postgresql import ARRAY from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship +from sqlalchemy.orm import Mapped, mapped_column, relationship +from ..models import Base from .schemas import ( AutoFailUnitType, EventType, + Experiment, Notifications, ObservationType, ) -if TYPE_CHECKING: - from .workspaces.models import WorkspaceDB - - -# Base class for SQLAlchemy models -class Base(DeclarativeBase): - """Base class for SQLAlchemy models""" - - pass - # --- Base model for experiments --- class ExperimentDB(Base): @@ -83,9 +75,6 @@ class ExperimentDB(Base): ) # Relationships - workspace: Mapped["WorkspaceDB"] = relationship( - "WorkspaceDB", back_populates="experiments" - ) arms: Mapped[list["ArmDB"]] = relationship( "ArmDB", back_populates="experiment", lazy="joined" ) @@ -108,11 +97,6 @@ class 
ExperimentDB(Base): + "ExperimentDB.exp_type=='cmab')", ) - __mapper_args__ = { - "polymorphic_identity": "experiment", - "polymorphic_on": "exp_type", - } - def __repr__(self) -> str: """ String representation of the model @@ -170,12 +154,15 @@ class ArmDB(Base): # IDs arm_id: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False) - experiment_id: Mapped[int] = mapped_column( - Integer, ForeignKey("experiments.experiment_id"), nullable=False - ) user_id: Mapped[int] = mapped_column( Integer, ForeignKey("users.user_id"), nullable=False ) + workspace_id: Mapped[int] = mapped_column( + Integer, ForeignKey("workspace.workspace_id"), nullable=False + ) + experiment_id: Mapped[int] = mapped_column( + Integer, ForeignKey("experiments.experiment_id"), nullable=False + ) # Description name: Mapped[str] = mapped_column(String(length=150), nullable=False) @@ -223,6 +210,7 @@ def to_dict(self) -> dict: "mu_init": self.mu_init, "sigma_init": self.sigma_init, "draws": [draw.to_dict() for draw in self.draws], + "n_outcomes": self.n_outcomes, } @@ -247,8 +235,8 @@ class DrawDB(Base): user_id: Mapped[int] = mapped_column( Integer, ForeignKey("users.user_id"), nullable=False ) - client_id = Mapped[str] = mapped_column( - String, ForeignKey("clients.client_id"), nullable=True + client_id: Mapped[str] = mapped_column( + String(length=36), ForeignKey("clients.client_id"), nullable=False ) # Logging @@ -263,7 +251,7 @@ class DrawDB(Base): Enum(ObservationType), nullable=True ) reward: Mapped[float] = mapped_column(Float, nullable=True) - context_val = Mapped[Optional[list[float]]] = mapped_column( + context_val: Mapped[Optional[list[float]]] = mapped_column( ARRAY(Float), nullable=True ) @@ -303,12 +291,12 @@ class ContextDB(Base): ORM for managing context for an experiment """ - __tablename__ = "contexts" + __tablename__ = "context" # IDs context_id: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False) experiment_id: Mapped[int] = mapped_column( - 
Integer, ForeignKey("contextual_mabs.experiment_id"), nullable=False + Integer, ForeignKey("experiments.experiment_id"), nullable=False ) user_id: Mapped[int] = mapped_column( Integer, ForeignKey("users.user_id"), nullable=False @@ -364,6 +352,13 @@ class ClientDB(Base): back_populates="client", lazy="joined", ) + experiment: Mapped[ExperimentDB] = relationship( + "ExperimentDB", + back_populates="clients", + lazy="joined", + primaryjoin="and_(ClientDB.experiment_id==ExperimentDB.experiment_id," + + "ExperimentDB.sticky_assignment == True)", + ) # --- Notifications model --- @@ -385,6 +380,9 @@ class NotificationsDB(Base): user_id: Mapped[int] = mapped_column( Integer, ForeignKey("users.user_id"), nullable=False ) + workspace_id: Mapped[int] = mapped_column( + Integer, ForeignKey("workspace.workspace_id"), nullable=False + ) notification_type: Mapped[EventType] = mapped_column( Enum(EventType), nullable=False ) @@ -393,7 +391,7 @@ class NotificationsDB(Base): def to_dict(self) -> dict: """ - Convert the model to a dictionary + Convert the model to a dictionary. """ return { "notification_id": self.notification_id, @@ -405,7 +403,7 @@ def to_dict(self) -> dict: } -# --- Experiments functions --- +# --- ORM functions --- # ---- Notifications functions ---- @@ -469,3 +467,78 @@ async def get_notifications_from_db( ) return (await asession.execute(statement)).scalars().all() + + +# --- Experiment functions --- +async def save_experiment_to_db( + experiment: Experiment, + user_id: int, + workspace_id: int, + asession: AsyncSession, +) -> ExperimentDB: + """ + Save an experiment to the database. 
+ """ + len_contexts = len(experiment.contexts) if experiment.contexts else 1 + contexts = None + + arms = [ + ArmDB( + user_id=user_id, + workspace_id=workspace_id, + # description + name=arm.name, + description=arm.description, + n_outcomes=0, + # prior variables + mu_init=arm.mu_init, + sigma_init=arm.sigma_init, + mu=[arm.mu_init] * len_contexts, + covariance=[arm.sigma_init] * len_contexts, + alpha_init=arm.alpha_init, + beta_init=arm.beta_init, + alpha=arm.alpha_init, + beta=arm.beta_init, + ) + for arm in experiment.arms + ] + if experiment.contexts and len_contexts > 0: + contexts = [ + ContextDB( + user_id=user_id, + name=context.name, + description=context.description, + value_type=context.value_type, + ) + for context in experiment.contexts + ] + + experiment_db = ExperimentDB( + user_id=user_id, + workspace_id=workspace_id, + # description + name=experiment.name, + description=experiment.description, + is_active=experiment.is_active, + # assignments config + sticky_assignment=experiment.sticky_assignment, + auto_fail=experiment.auto_fail, + auto_fail_value=experiment.auto_fail_value, + auto_fail_unit=experiment.auto_fail_unit, + # experiment config + exp_type=experiment.exp_type, + prior_type=experiment.prior_type, + reward_type=experiment.reward_type, + # datetime + created_datetime_utc=datetime.now(timezone.utc), + n_trials=0, + # relationships + arms=arms, + contexts=contexts, + ) + + asession.add(experiment_db) + await asession.commit() + await asession.refresh(experiment_db) + + return experiment_db diff --git a/backend/app/experiments/routers.py b/backend/app/experiments/routers.py new file mode 100644 index 0000000..d9eba78 --- /dev/null +++ b/backend/app/experiments/routers.py @@ -0,0 +1,56 @@ +from typing import Annotated + +from fastapi import APIRouter, Depends +from fastapi.exceptions import HTTPException +from sqlalchemy.ext.asyncio import AsyncSession + +from ..auth.dependencies import ( + require_admin_role, +) +from ..database import 
get_async_session +from ..users.models import UserDB +from ..utils import setup_logger +from ..workspaces.models import ( + get_user_default_workspace, +) +from .models import save_experiment_to_db, save_notifications_to_db +from .schemas import Experiment, ExperimentResponse + +router = APIRouter(prefix="/experiment", tags=["Experiments"]) + +logger = setup_logger(__name__) + + +@router.post("/", response_model=ExperimentResponse) +async def create_experiment( + experiment: Experiment, + user_db: Annotated[UserDB, Depends(require_admin_role)], + asession: AsyncSession = Depends(get_async_session), +) -> ExperimentResponse: + """ + Create a new experiment in the current user's workspace. + """ + workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) + + if workspace_db is None: + raise HTTPException( + status_code=404, + detail="Workspace not found. Please create a workspace first.", + ) + + experiment_db = await save_experiment_to_db( + experiment=experiment, + workspace_id=workspace_db.workspace_id, + user_id=user_db.user_id, + asession=asession, + ) + notifications = await save_notifications_to_db( + experiment_id=experiment_db.experiment_id, + user_id=user_db.user_id, + notifications=experiment.notifications, + asession=asession, + ) + + experiment_dict = experiment_db.to_dict() + experiment_dict["notifications"] = [n.to_dict() for n in notifications] + return ExperimentResponse.model_validate(experiment_dict) diff --git a/backend/app/experiments/schemas.py b/backend/app/experiments/schemas.py index 91c38ca..e978985 100644 --- a/backend/app/experiments/schemas.py +++ b/backend/app/experiments/schemas.py @@ -1,5 +1,5 @@ from enum import Enum, StrEnum -from typing import Any, Optional, Self +from typing import Any, List, Optional, Self, Union import numpy as np from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -238,10 +238,11 @@ class ArmResponse(Arm): arm_id: int experiment_id: int n_outcomes: int - alpha: 
Optional[float] - beta: Optional[float] - mu: Optional[list[float]] - covariance: Optional[list[float]] + alpha: Optional[Union[float, None]] + beta: Optional[Union[float, None]] + mu: Optional[List[Union[float, None]]] + covariance: Optional[List[Union[float, None]]] + draws: Optional[List[Union[float, None]]] model_config = ConfigDict( from_attributes=True, ) @@ -409,9 +410,9 @@ class Experiment(ExperimentBase): # Relationships arms: list[Arm] + notifications: Notifications contexts: Optional[list[Context]] = None clients: Optional[list[Client]] = None - notifications = Notifications @model_validator(mode="after") def auto_fail_unit_and_value_set(self) -> Self: @@ -479,11 +480,26 @@ def check_prior_reward_type_combo(self) -> Self: if self.prior_type == ArmPriors.BETA: if not self.reward_type == RewardLikelihood.BERNOULLI: raise ValueError( - "Beta prior can only be used with Bernoulli reward type." + "Beta prior can only be used with binary-valued rewards." ) return self + @model_validator(mode="after") + def check_contexts(self) -> Self: + """ + Validate that the contexts inputs are valid. + """ + if self.exp_type == "cmab" and not self.contexts: + raise ValueError("Contextual MAB experiments require at least one context.") + if self.exp_type != "cmab" and self.contexts: + raise ValueError( + "Contexts are only applicable for contextual MAB experiments." 
+ ) + return self + + model_config = ConfigDict(from_attributes=True) + class ExperimentResponse(ExperimentBase): """ @@ -495,8 +511,8 @@ class ExperimentResponse(ExperimentBase): last_trial_datetime_utc: Optional[str] = None arms: list[ArmResponse] + notifications: list[NotificationsResponse] contexts: Optional[list[ContextResponse]] = None clients: Optional[list[Client]] = None - notifications: NotificationsResponse model_config = ConfigDict(from_attributes=True) diff --git a/backend/app/models.py b/backend/app/models.py index 097aa2b..4b565b6 100644 --- a/backend/app/models.py +++ b/backend/app/models.py @@ -1,6 +1,6 @@ import uuid from datetime import datetime -from typing import TYPE_CHECKING, Sequence +from typing import Sequence from sqlalchemy import ( Boolean, @@ -13,13 +13,10 @@ select, ) from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column from .schemas import AutoFailUnitType, EventType, Notifications, ObservationType -if TYPE_CHECKING: - from .workspaces.models import WorkspaceDB - class Base(DeclarativeBase): """Base class for SQLAlchemy models""" @@ -66,9 +63,6 @@ class ExperimentBaseDB(Base): last_trial_datetime_utc: Mapped[datetime] = mapped_column( DateTime(timezone=True), nullable=True ) - workspace: Mapped["WorkspaceDB"] = relationship( - "WorkspaceDB", back_populates="experiments" - ) __mapper_args__ = { "polymorphic_identity": "experiment", @@ -161,7 +155,7 @@ class NotificationsDB(Base): the background celery job """ - __tablename__ = "notifications" + __tablename__ = "notifications_db" notification_id: Mapped[int] = mapped_column( Integer, primary_key=True, nullable=False diff --git a/backend/app/workspaces/models.py b/backend/app/workspaces/models.py index b79aeab..52146cc 100644 --- a/backend/app/workspaces/models.py +++ b/backend/app/workspaces/models.py @@ -19,7 +19,7 @@ from 
sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Mapped, mapped_column, relationship -from ..models import Base, ExperimentBaseDB +from ..models import Base from ..users.exceptions import UserNotFoundError from ..users.schemas import UserCreate from .schemas import UserCreateWithCode, UserRoles @@ -77,9 +77,6 @@ class WorkspaceDB(Base): workspace_id: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False) workspace_name: Mapped[str] = mapped_column(String, nullable=False, unique=True) is_default: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) - experiments: Mapped[list["ExperimentBaseDB"]] = relationship( - "ExperimentBaseDB", back_populates="workspace", cascade="all, delete-orphan" - ) pending_invitations: Mapped[list["PendingInvitationDB"]] = relationship( "PendingInvitationDB", back_populates="workspace", cascade="all, delete-orphan" diff --git a/backend/tests/test_cmabs.py b/backend/tests/test_cmabs.py deleted file mode 100644 index d9b6ed0..0000000 --- a/backend/tests/test_cmabs.py +++ /dev/null @@ -1,452 +0,0 @@ -import copy -import os -from typing import Generator - -import numpy as np -from fastapi.testclient import TestClient -from pytest import FixtureRequest, fixture, mark -from sqlalchemy.orm import Session - -from backend.app.contextual_mab.models import ( - ContextDB, - ContextualArmDB, - ContextualBanditDB, -) -from backend.app.models import NotificationsDB - -base_normal_payload = { - "name": "Test", - "description": "Test description", - "prior_type": "normal", - "reward_type": "real-valued", - "sticky_assignment": False, - "arms": [ - { - "name": "arm 1", - "description": "arm 1 description", - "mu_init": 0, - "sigma_init": 1, - }, - { - "name": "arm 2", - "description": "arm 2 description", - "mu_init": 0, - "sigma_init": 1, - }, - ], - "contexts": [ - { - "name": "Context 1", - "description": "context 1 description", - "value_type": "binary", - }, - { - "name": "Context 2", - 
"description": "context 2 description", - "value_type": "real-valued", - }, - ], - "notifications": { - "onTrialCompletion": True, - "numberOfTrials": 2, - "onDaysElapsed": False, - "daysElapsed": 3, - "onPercentBetter": False, - "percentBetterThreshold": 5, - }, -} - -base_binary_normal_payload = base_normal_payload.copy() -base_binary_normal_payload["reward_type"] = "binary" - - -@fixture -def admin_token(client: TestClient) -> str: - response = client.post( - "/login", - data={ - "username": os.environ.get("ADMIN_USERNAME", ""), - "password": os.environ.get("ADMIN_PASSWORD", ""), - }, - ) - token = response.json()["access_token"] - return token - - -@fixture -def clean_cmabs(db_session: Session) -> Generator: - yield - db_session.query(NotificationsDB).delete() - db_session.query(ContextualArmDB).delete() - db_session.query(ContextDB).delete() - db_session.query(ContextualBanditDB).delete() - db_session.commit() - - -class TestCMab: - @fixture - def create_cmab_payload(self, request: FixtureRequest) -> dict: - payload_normal: dict = copy.deepcopy(base_normal_payload) - payload_normal["arms"] = list(payload_normal["arms"]) - payload_normal["contexts"] = list(payload_normal["contexts"]) - - payload_binary_normal: dict = copy.deepcopy(base_binary_normal_payload) - payload_binary_normal["arms"] = list(payload_binary_normal["arms"]) - payload_binary_normal["contexts"] = list(payload_binary_normal["contexts"]) - - if request.param == "base_normal": - return payload_normal - if request.param == "base_binary_normal": - return payload_binary_normal - if request.param == "one_arm": - payload_normal["arms"].pop() - return payload_normal - if request.param == "no_notifications": - payload_normal["notifications"]["onTrialCompletion"] = False - return payload_normal - if request.param == "invalid_prior": - payload_normal["prior_type"] = "beta" - return payload_normal - if request.param == "invalid_reward": - payload_normal["reward_type"] = "invalid" - return payload_normal - 
if request.param == "invalid_sigma": - payload_normal["arms"][0]["sigma_init"] = 0 - return payload_normal - if request.param == "with_sticky_assignment": - payload_normal["sticky_assignment"] = True - return payload_normal - - else: - raise ValueError("Invalid parameter") - - @mark.parametrize( - "create_cmab_payload, expected_response", - [ - ("base_normal", 200), - ("base_binary_normal", 200), - ("one_arm", 422), - ("no_notifications", 200), - ("invalid_prior", 422), - ("invalid_reward", 422), - ("invalid_sigma", 422), - ], - indirect=["create_cmab_payload"], - ) - def test_create_cmab( - self, - create_cmab_payload: dict, - client: TestClient, - expected_response: int, - admin_token: str, - clean_cmabs: None, - ) -> None: - response = client.post( - "/contextual_mab", - json=create_cmab_payload, - headers={"Authorization": f"Bearer {admin_token}"}, - ) - - assert response.status_code == expected_response - - @fixture - def create_cmabs( - self, - client: TestClient, - admin_token: str, - request: FixtureRequest, - create_cmab_payload: dict, - ) -> Generator: - cmabs = [] - n_cmabs = request.param if hasattr(request, "param") else 1 - for _ in range(n_cmabs): - response = client.post( - "/contextual_mab", - json=create_cmab_payload, - headers={"Authorization": f"Bearer {admin_token}"}, - ) - cmabs.append(response.json()) - yield cmabs - for cmab in cmabs: - client.delete( - f"/contextual_mab/{cmab['experiment_id']}", - headers={"Authorization": f"Bearer {admin_token}"}, - ) - - @mark.parametrize( - "create_cmabs, n_expected, create_cmab_payload", - [(0, 0, "base_normal"), (2, 2, "base_normal"), (5, 5, "base_normal")], - indirect=["create_cmabs", "create_cmab_payload"], - ) - def test_get_all_cmabs( - self, - client: TestClient, - admin_token: str, - n_expected: int, - create_cmab_payload: dict, - create_cmabs: list, - ) -> None: - response = client.get( - "/contextual_mab", headers={"Authorization": f"Bearer {admin_token}"} - ) - assert response.status_code == 
200 - assert len(response.json()) == n_expected - - @mark.parametrize( - "create_cmabs, expected_response, create_cmab_payload", - [(0, 404, "base_normal"), (2, 200, "base_normal")], - indirect=["create_cmabs", "create_cmab_payload"], - ) - def test_get_cmab( - self, - client: TestClient, - admin_token: str, - create_cmab_payload: dict, - create_cmabs: list, - expected_response: int, - ) -> None: - id = create_cmabs[0]["experiment_id"] if create_cmabs else 999 - - response = client.get( - f"/contextual_mab/{id}", headers={"Authorization": f"Bearer {admin_token}"} - ) - assert response.status_code == expected_response - - @mark.parametrize("create_cmab_payload", ["base_normal"], indirect=True) - def test_draw_arm_draw_id_provided( - self, - client: TestClient, - create_cmabs: list, - create_cmab_payload: dict, - workspace_api_key: str, - ) -> None: - id = create_cmabs[0]["experiment_id"] - response = client.post( - f"/contextual_mab/{id}/draw", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - params={"draw_id": "test_draw_id"}, - json=[ - {"context_id": 1, "context_value": 0}, - {"context_id": 2, "context_value": 0.5}, - ], - ) - assert response.status_code == 200 - assert response.json()["draw_id"] == "test_draw_id" - - @mark.parametrize("create_cmab_payload", ["base_normal"], indirect=True) - def test_draw_arm_no_draw_id_provided( - self, - client: TestClient, - create_cmabs: list, - create_cmab_payload: dict, - workspace_api_key: str, - ) -> None: - id = create_cmabs[0]["experiment_id"] - response = client.post( - f"/contextual_mab/{id}/draw", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - json=[ - {"context_id": 1, "context_value": 0}, - {"context_id": 2, "context_value": 0.5}, - ], - ) - assert response.status_code == 200 - assert len(response.json()["draw_id"]) == 36 - - @mark.parametrize( - "create_cmab_payload, client_id, expected_response", - [ - ("with_sticky_assignment", None, 400), - ("with_sticky_assignment", 
"test_client_id", 200), - ], - indirect=["create_cmab_payload"], - ) - def test_draw_arm_sticky_assignment_client_id_provided( - self, - client: TestClient, - create_cmabs: list, - create_cmab_payload: dict, - client_id: str | None, - expected_response: int, - workspace_api_key: str, - ) -> None: - id = create_cmabs[0]["experiment_id"] - url = f"/contextual_mab/{id}/draw" - if client_id: - url += f"?client_id={client_id}" - - response = client.post( - url, - headers={"Authorization": f"Bearer {workspace_api_key}"}, - json=[ - {"context_id": 1, "context_value": 0}, - {"context_id": 2, "context_value": 0.5}, - ], - ) - assert response.status_code == expected_response - - @mark.parametrize("create_cmab_payload", ["with_sticky_assignment"], indirect=True) - def test_draw_arm_with_sticky_assignment( - self, - client: TestClient, - create_cmabs: list, - create_cmab_payload: dict, - workspace_api_key: str, - ) -> None: - id = create_cmabs[0]["experiment_id"] - arm_ids = [] - - for _ in range(10): - response = client.post( - f"/contextual_mab/{id}/draw?client_id=123", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - json=[ - {"context_id": 1, "context_value": 0}, - {"context_id": 2, "context_value": 1}, - ], - ) - arm_ids.append(response.json()["arm"]["arm_id"]) - - assert np.unique(arm_ids).size == 1 - - @mark.parametrize("create_cmab_payload", ["base_normal"], indirect=True) - def test_one_outcome_per_draw( - self, - client: TestClient, - create_cmabs: list, - create_cmab_payload: dict, - workspace_api_key: str, - ) -> None: - id = create_cmabs[0]["experiment_id"] - response = client.post( - f"/contextual_mab/{id}/draw", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - json=[ - {"context_id": 1, "context_value": 0}, - {"context_id": 2, "context_value": 0.5}, - ], - ) - assert response.status_code == 200 - draw_id = response.json()["draw_id"] - - response = client.put( - f"/contextual_mab/{id}/{draw_id}/1", - headers={"Authorization": f"Bearer 
{workspace_api_key}"}, - ) - - assert response.status_code == 200 - - response = client.put( - f"/contextual_mab/{id}/{draw_id}/1", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - - assert response.status_code == 400 - - @mark.parametrize( - "n_draws, create_cmab_payload", - [(0, "base_normal"), (1, "base_normal"), (5, "base_normal")], - indirect=["create_cmab_payload"], - ) - def test_get_outcomes( - self, - client: TestClient, - create_cmabs: list, - n_draws: int, - create_cmab_payload: dict, - workspace_api_key: str, - ) -> None: - id = create_cmabs[0]["experiment_id"] - - for _ in range(n_draws): - response = client.post( - f"/contextual_mab/{id}/draw", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - json=[ - {"context_id": 1, "context_value": 0}, - {"context_id": 2, "context_value": 0.5}, - ], - ) - assert response.status_code == 200 - draw_id = response.json()["draw_id"] - response = client.put( - f"/contextual_mab/{id}/{draw_id}/1", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - - response = client.get( - f"/contextual_mab/{id}/outcomes", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - - assert response.status_code == 200 - assert len(response.json()) == n_draws - - -class TestNotifications: - @fixture() - def create_cmab_payload(self, request: FixtureRequest) -> dict: - payload: dict = copy.deepcopy(base_normal_payload) - payload["arms"] = list(payload["arms"]) - payload["contexts"] = list(payload["contexts"]) - - match request.param: - case "base": - pass - case "daysElapsed_only": - payload["notifications"]["onTrialCompletion"] = False - payload["notifications"]["onDaysElapsed"] = True - case "trialCompletion_only": - payload["notifications"]["onTrialCompletion"] = True - case "percentBetter_only": - payload["notifications"]["onTrialCompletion"] = False - payload["notifications"]["onPercentBetter"] = True - case "all_notifications": - payload["notifications"]["onDaysElapsed"] = True - 
payload["notifications"]["onPercentBetter"] = True - case "no_notifications": - payload["notifications"]["onTrialCompletion"] = False - case "daysElapsed_missing": - payload["notifications"]["daysElapsed"] = 0 - payload["notifications"]["onDaysElapsed"] = True - case "trialCompletion_missing": - payload["notifications"]["numberOfTrials"] = 0 - payload["notifications"]["onTrialCompletion"] = True - case "percentBetter_missing": - payload["notifications"]["percentBetterThreshold"] = 0 - payload["notifications"]["onPercentBetter"] = True - case _: - raise ValueError("Invalid parameter") - - return payload - - @mark.parametrize( - "create_cmab_payload, expected_response", - [ - ("base", 200), - ("daysElapsed_only", 200), - ("trialCompletion_only", 200), - ("percentBetter_only", 200), - ("all_notifications", 200), - ("no_notifications", 200), - ("daysElapsed_missing", 422), - ("trialCompletion_missing", 422), - ("percentBetter_missing", 422), - ], - indirect=["create_cmab_payload"], - ) - def test_notifications( - self, - client: TestClient, - admin_token: str, - create_cmab_payload: dict, - expected_response: int, - clean_cmabs: None, - ) -> None: - response = client.post( - "/contextual_mab", - json=create_cmab_payload, - headers={"Authorization": f"Bearer {admin_token}"}, - ) - - assert response.status_code == expected_response diff --git a/deployment/docker-compose/docker-compose-dev.yml b/deployment/docker-compose/docker-compose-dev.yml index 6a7b5df..25db623 100644 --- a/deployment/docker-compose/docker-compose-dev.yml +++ b/deployment/docker-compose/docker-compose-dev.yml @@ -4,13 +4,17 @@ services: build: context: ../../backend dockerfile: Dockerfile + entrypoint: ["/bin/sh", "-c"] command: > - python -m alembic upgrade head && python add_users_to_db.py && uvicorn main:app --host 0.0.0.0 --port 8000 --reload + "python -m alembic upgrade head && + python add_users_to_db.py && + uvicorn main:app --host 0.0.0.0 --port 8000 --reload + " restart: always ports: - 
"8000:8000" volumes: - # - temp:/usr/src/experiment_engine_backend/temp + - temp:/usr/src/experiment_engine_backend/temp - ../../backend:/usr/src/experiment_engine_backend env_file: - .base.env @@ -65,7 +69,7 @@ services: redis: image: "redis:6.0-alpine" - ports: # Expose the port to port 6380 on the host machine for debugging + ports: - "6380:6379" restart: always @@ -73,4 +77,4 @@ volumes: db_volume: caddy_data: caddy_config: - # temp: + temp: diff --git a/deployment/docker-compose/docker-compose.yml b/deployment/docker-compose/docker-compose.yml index cd1dab0..8e2b144 100644 --- a/deployment/docker-compose/docker-compose.yml +++ b/deployment/docker-compose/docker-compose.yml @@ -60,7 +60,7 @@ services: redis: image: "redis:6.0-alpine" - ports: # Expose the port to port 6380 on the host machine for debugging + ports: - "6380:6379" restart: always From 1cff2411f57cca267cc07342735c4f5a0d5fc2f2 Mon Sep 17 00:00:00 2001 From: poornimaramesh Date: Tue, 27 May 2025 19:50:01 +0300 Subject: [PATCH 33/74] working version of get all mabs endpoint --- backend/app/experiments/models.py | 14 ++++++++ backend/app/experiments/routers.py | 54 ++++++++++++++++++++++++++++-- 2 files changed, 66 insertions(+), 2 deletions(-) diff --git a/backend/app/experiments/models.py b/backend/app/experiments/models.py index b7df1c4..65d88af 100644 --- a/backend/app/experiments/models.py +++ b/backend/app/experiments/models.py @@ -542,3 +542,17 @@ async def save_experiment_to_db( await asession.refresh(experiment_db) return experiment_db + + +async def get_all_experiments_from_db( + workspace_id: int, asession: AsyncSession +) -> Sequence[ExperimentDB]: + """ + Get all experiments for a given workspace. 
+ """ + statement = ( + select(ExperimentDB) + .where(ExperimentDB.workspace_id == workspace_id) + .order_by(ExperimentDB.created_datetime_utc.desc()) + ) + return (await asession.execute(statement)).unique().scalars().all() diff --git a/backend/app/experiments/routers.py b/backend/app/experiments/routers.py index d9eba78..fc2d51c 100644 --- a/backend/app/experiments/routers.py +++ b/backend/app/experiments/routers.py @@ -5,6 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from ..auth.dependencies import ( + get_verified_user, require_admin_role, ) from ..database import get_async_session @@ -13,8 +14,13 @@ from ..workspaces.models import ( get_user_default_workspace, ) -from .models import save_experiment_to_db, save_notifications_to_db -from .schemas import Experiment, ExperimentResponse +from .models import ( + get_all_experiments_from_db, + get_notifications_from_db, + save_experiment_to_db, + save_notifications_to_db, +) +from .schemas import Experiment, ExperimentResponse, NotificationsResponse router = APIRouter(prefix="/experiment", tags=["Experiments"]) @@ -54,3 +60,47 @@ async def create_experiment( experiment_dict = experiment_db.to_dict() experiment_dict["notifications"] = [n.to_dict() for n in notifications] return ExperimentResponse.model_validate(experiment_dict) + + +@router.get("/", response_model=list[ExperimentResponse]) +async def get_all_experiments( + user_db: Annotated[UserDB, Depends(get_verified_user)], + asession: AsyncSession = Depends(get_async_session), +) -> list[ExperimentResponse]: + """ + Retrieve all experiments for the current user's workspace. + """ + workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) + + if workspace_db is None: + raise HTTPException( + status_code=404, + detail="Workspace not found. 
Please create a workspace first.", + ) + + experiments = await get_all_experiments_from_db( + workspace_id=workspace_db.workspace_id, + asession=asession, + ) + + all_experiments = [] + for exp in experiments: + exp_dict = exp.to_dict() + exp_dict["notifications"] = [ + n.to_dict() + for n in await get_notifications_from_db( + experiment_id=exp.experiment_id, user_id=exp.user_id, asession=asession + ) + ] + all_experiments.append( + ExperimentResponse.model_validate( + { + **exp_dict, + "notifications": [ + NotificationsResponse(**n) for n in exp_dict["notifications"] + ], + } + ) + ) + + return all_experiments From 797b0502b9356baa997b79ac24abb64cb08557e0 Mon Sep 17 00:00:00 2001 From: poornimaramesh Date: Tue, 27 May 2025 19:57:32 +0300 Subject: [PATCH 34/74] add epxeriment type get router --- backend/app/experiments/models.py | 15 +++++++++ backend/app/experiments/routers.py | 54 +++++++++++++++++++++++++++++- 2 files changed, 68 insertions(+), 1 deletion(-) diff --git a/backend/app/experiments/models.py b/backend/app/experiments/models.py index 65d88af..78494f9 100644 --- a/backend/app/experiments/models.py +++ b/backend/app/experiments/models.py @@ -556,3 +556,18 @@ async def get_all_experiments_from_db( .order_by(ExperimentDB.created_datetime_utc.desc()) ) return (await asession.execute(statement)).unique().scalars().all() + + +async def get_all_experiment_types_from_db( + workspace_id: int, experiment_type: str, asession: AsyncSession +) -> Sequence[ExperimentDB]: + """ + Get all experiments for a given workspace. 
+ """ + statement = ( + select(ExperimentDB) + .where(ExperimentDB.workspace_id == workspace_id) + .where(ExperimentDB.exp_type == experiment_type) + .order_by(ExperimentDB.created_datetime_utc.desc()) + ) + return (await asession.execute(statement)).unique().scalars().all() diff --git a/backend/app/experiments/routers.py b/backend/app/experiments/routers.py index fc2d51c..d8cfc8d 100644 --- a/backend/app/experiments/routers.py +++ b/backend/app/experiments/routers.py @@ -15,12 +15,18 @@ get_user_default_workspace, ) from .models import ( + get_all_experiment_types_from_db, get_all_experiments_from_db, get_notifications_from_db, save_experiment_to_db, save_notifications_to_db, ) -from .schemas import Experiment, ExperimentResponse, NotificationsResponse +from .schemas import ( + Experiment, + ExperimentResponse, + ExperimentsEnum, + NotificationsResponse, +) router = APIRouter(prefix="/experiment", tags=["Experiments"]) @@ -104,3 +110,49 @@ async def get_all_experiments( ) return all_experiments + + +@router.get("/{experiment_type}", response_model=list[ExperimentResponse]) +async def get_all_experiments_by_type( + experiment_type: ExperimentsEnum, + user_db: Annotated[UserDB, Depends(get_verified_user)], + asession: AsyncSession = Depends(get_async_session), +) -> list[ExperimentResponse]: + """ + Retrieve all experiments for the current user's workspace. + """ + workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) + + if workspace_db is None: + raise HTTPException( + status_code=404, + detail="Workspace not found. 
Please create a workspace first.", + ) + + experiments = await get_all_experiment_types_from_db( + workspace_id=workspace_db.workspace_id, + experiment_type=experiment_type.value, + asession=asession, + ) + + all_experiments = [] + for exp in experiments: + exp_dict = exp.to_dict() + exp_dict["notifications"] = [ + n.to_dict() + for n in await get_notifications_from_db( + experiment_id=exp.experiment_id, user_id=exp.user_id, asession=asession + ) + ] + all_experiments.append( + ExperimentResponse.model_validate( + { + **exp_dict, + "notifications": [ + NotificationsResponse(**n) for n in exp_dict["notifications"] + ], + } + ) + ) + + return all_experiments From 1f3beb3ba69141e25ba7618b2530f6054ab8d44f Mon Sep 17 00:00:00 2001 From: poornimaramesh Date: Tue, 27 May 2025 20:50:49 +0300 Subject: [PATCH 35/74] add get routers for exp by id --- backend/app/experiments/dependencies.py | 34 +++++++++ backend/app/experiments/models.py | 14 ++++ backend/app/experiments/routers.py | 91 +++++++++++++------------ 3 files changed, 97 insertions(+), 42 deletions(-) create mode 100644 backend/app/experiments/dependencies.py diff --git a/backend/app/experiments/dependencies.py b/backend/app/experiments/dependencies.py new file mode 100644 index 0000000..bad87a8 --- /dev/null +++ b/backend/app/experiments/dependencies.py @@ -0,0 +1,34 @@ +from sqlalchemy.ext.asyncio import AsyncSession + +from .models import ExperimentDB, get_notifications_from_db +from .schemas import ExperimentResponse, NotificationsResponse + + +async def experiments_db_to_schema( + experiments_db: list[ExperimentDB], + asession: AsyncSession, +) -> list[ExperimentResponse]: + """ + Convert a list of ExperimentDB objects to a list of ExperimentResponse schemas. 
+ """ + all_experiments = [] + for exp in experiments_db: + exp_dict = exp.to_dict() + exp_dict["notifications"] = [ + n.to_dict() + for n in await get_notifications_from_db( + experiment_id=exp.experiment_id, user_id=exp.user_id, asession=asession + ) + ] + all_experiments.append( + ExperimentResponse.model_validate( + { + **exp_dict, + "notifications": [ + NotificationsResponse(**n) for n in exp_dict["notifications"] + ], + } + ) + ) + + return all_experiments diff --git a/backend/app/experiments/models.py b/backend/app/experiments/models.py index 78494f9..f9e5bef 100644 --- a/backend/app/experiments/models.py +++ b/backend/app/experiments/models.py @@ -571,3 +571,17 @@ async def get_all_experiment_types_from_db( .order_by(ExperimentDB.created_datetime_utc.desc()) ) return (await asession.execute(statement)).unique().scalars().all() + + +async def get_experiment_by_id_from_db( + workspace_id: int, experiment_id: int, asession: AsyncSession +) -> ExperimentDB | None: + """ + Get all experiments for a given workspace. 
+ """ + statement = ( + select(ExperimentDB) + .where(ExperimentDB.workspace_id == workspace_id) + .where(ExperimentDB.experiment_id == experiment_id) + ) + return (await asession.execute(statement)).unique().scalars().one_or_none() diff --git a/backend/app/experiments/routers.py b/backend/app/experiments/routers.py index d8cfc8d..c63c732 100644 --- a/backend/app/experiments/routers.py +++ b/backend/app/experiments/routers.py @@ -14,10 +14,11 @@ from ..workspaces.models import ( get_user_default_workspace, ) +from .dependencies import experiments_db_to_schema from .models import ( get_all_experiment_types_from_db, get_all_experiments_from_db, - get_notifications_from_db, + get_experiment_by_id_from_db, save_experiment_to_db, save_notifications_to_db, ) @@ -25,7 +26,6 @@ Experiment, ExperimentResponse, ExperimentsEnum, - NotificationsResponse, ) router = APIRouter(prefix="/experiment", tags=["Experiments"]) @@ -33,6 +33,7 @@ logger = setup_logger(__name__) +# --- POST experiments routers --- @router.post("/", response_model=ExperimentResponse) async def create_experiment( experiment: Experiment, @@ -68,6 +69,7 @@ async def create_experiment( return ExperimentResponse.model_validate(experiment_dict) +# -- GET experiment routers --- @router.get("/", response_model=list[ExperimentResponse]) async def get_all_experiments( user_db: Annotated[UserDB, Depends(get_verified_user)], @@ -89,30 +91,14 @@ async def get_all_experiments( asession=asession, ) - all_experiments = [] - for exp in experiments: - exp_dict = exp.to_dict() - exp_dict["notifications"] = [ - n.to_dict() - for n in await get_notifications_from_db( - experiment_id=exp.experiment_id, user_id=exp.user_id, asession=asession - ) - ] - all_experiments.append( - ExperimentResponse.model_validate( - { - **exp_dict, - "notifications": [ - NotificationsResponse(**n) for n in exp_dict["notifications"] - ], - } - ) - ) - + all_experiments = await experiments_db_to_schema( + experiments_db=list(experiments), + 
asession=asession, + ) return all_experiments -@router.get("/{experiment_type}", response_model=list[ExperimentResponse]) +@router.get("/type/{experiment_type}", response_model=list[ExperimentResponse]) async def get_all_experiments_by_type( experiment_type: ExperimentsEnum, user_db: Annotated[UserDB, Depends(get_verified_user)], @@ -135,24 +121,45 @@ async def get_all_experiments_by_type( asession=asession, ) - all_experiments = [] - for exp in experiments: - exp_dict = exp.to_dict() - exp_dict["notifications"] = [ - n.to_dict() - for n in await get_notifications_from_db( - experiment_id=exp.experiment_id, user_id=exp.user_id, asession=asession - ) - ] - all_experiments.append( - ExperimentResponse.model_validate( - { - **exp_dict, - "notifications": [ - NotificationsResponse(**n) for n in exp_dict["notifications"] - ], - } - ) + all_experiments = await experiments_db_to_schema( + experiments_db=list(experiments), + asession=asession, + ) + return all_experiments + + +@router.get("/id/{experiment_id}", response_model=ExperimentResponse) +async def get_experiment_by_id( + experiment_id: int, + user_db: Annotated[UserDB, Depends(get_verified_user)], + asession: AsyncSession = Depends(get_async_session), +) -> ExperimentResponse: + """ + Retrieve a specific experiment by ID for the current user's workspace. + """ + workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) + + if workspace_db is None: + raise HTTPException( + status_code=404, + detail="Workspace not found. 
Please create a workspace first.", + ) + + experiment = await get_experiment_by_id_from_db( + workspace_id=workspace_db.workspace_id, + experiment_id=experiment_id, + asession=asession, + ) + + if not experiment: + raise HTTPException( + status_code=404, + detail="Experiment not found.", ) - return all_experiments + experiment_dict = await experiments_db_to_schema( + experiments_db=[experiment], + asession=asession, + ) + + return experiment_dict[0] From 2947b90536d3fc373cb35f711b5ee0aeb414660b Mon Sep 17 00:00:00 2001 From: poornimaramesh Date: Tue, 27 May 2025 21:33:00 +0300 Subject: [PATCH 36/74] debugging endpoints --- backend/app/experiments/dependencies.py | 5 +- backend/app/experiments/models.py | 62 ++++++++++++++++++++++++- backend/app/experiments/routers.py | 48 +++++++++++++++++++ backend/app/experiments/schemas.py | 4 +- 4 files changed, 114 insertions(+), 5 deletions(-) diff --git a/backend/app/experiments/dependencies.py b/backend/app/experiments/dependencies.py index bad87a8..b305477 100644 --- a/backend/app/experiments/dependencies.py +++ b/backend/app/experiments/dependencies.py @@ -17,7 +17,10 @@ async def experiments_db_to_schema( exp_dict["notifications"] = [ n.to_dict() for n in await get_notifications_from_db( - experiment_id=exp.experiment_id, user_id=exp.user_id, asession=asession + experiment_id=exp.experiment_id, + user_id=exp.user_id, + workspace_id=exp.workspace_id, + asession=asession, ) ] all_experiments.append( diff --git a/backend/app/experiments/models.py b/backend/app/experiments/models.py index f9e5bef..d8805e6 100644 --- a/backend/app/experiments/models.py +++ b/backend/app/experiments/models.py @@ -10,6 +10,7 @@ ForeignKey, Integer, String, + delete, select, ) from sqlalchemy.dialects.postgresql import ARRAY @@ -235,6 +236,9 @@ class DrawDB(Base): user_id: Mapped[int] = mapped_column( Integer, ForeignKey("users.user_id"), nullable=False ) + workspace_id: Mapped[int] = mapped_column( + Integer, 
ForeignKey("workspace.workspace_id"), nullable=False + ) client_id: Mapped[str] = mapped_column( String(length=36), ForeignKey("clients.client_id"), nullable=False ) @@ -301,6 +305,9 @@ class ContextDB(Base): user_id: Mapped[int] = mapped_column( Integer, ForeignKey("users.user_id"), nullable=False ) + workspace_id: Mapped[int] = mapped_column( + Integer, ForeignKey("workspace.workspace_id"), nullable=False + ) # Description name: Mapped[str] = mapped_column(String(length=150), nullable=False) @@ -410,6 +417,7 @@ def to_dict(self) -> dict: async def save_notifications_to_db( experiment_id: int, user_id: int, + workspace_id: int, notifications: Notifications, asession: AsyncSession, ) -> list[NotificationsDB]: @@ -422,6 +430,7 @@ async def save_notifications_to_db( notification_row = NotificationsDB( experiment_id=experiment_id, user_id=user_id, + workspace_id=workspace_id, notification_type=EventType.TRIALS_COMPLETED, notification_value=notifications.numberOfTrials, is_active=True, @@ -432,6 +441,7 @@ async def save_notifications_to_db( notification_row = NotificationsDB( experiment_id=experiment_id, user_id=user_id, + workspace_id=workspace_id, notification_type=EventType.DAYS_ELAPSED, notification_value=notifications.daysElapsed, is_active=True, @@ -442,6 +452,7 @@ async def save_notifications_to_db( notification_row = NotificationsDB( experiment_id=experiment_id, user_id=user_id, + workspace_id=workspace_id, notification_type=EventType.PERCENTAGE_BETTER, notification_value=notifications.percentBetterThreshold, is_active=True, @@ -455,7 +466,7 @@ async def save_notifications_to_db( async def get_notifications_from_db( - experiment_id: int, user_id: int, asession: AsyncSession + experiment_id: int, user_id: int, workspace_id: int, asession: AsyncSession ) -> Sequence[NotificationsDB]: """ Get notifications from the database @@ -464,6 +475,7 @@ async def get_notifications_from_db( select(NotificationsDB) .where(NotificationsDB.experiment_id == experiment_id) 
.where(NotificationsDB.user_id == user_id) + .where(NotificationsDB.workspace_id == workspace_id) ) return (await asession.execute(statement)).scalars().all() @@ -480,7 +492,7 @@ async def save_experiment_to_db( Save an experiment to the database. """ len_contexts = len(experiment.contexts) if experiment.contexts else 1 - contexts = None + contexts = [] arms = [ ArmDB( @@ -506,6 +518,7 @@ async def save_experiment_to_db( contexts = [ ContextDB( user_id=user_id, + workspace_id=workspace_id, name=context.name, description=context.description, value_type=context.value_type, @@ -585,3 +598,48 @@ async def get_experiment_by_id_from_db( .where(ExperimentDB.experiment_id == experiment_id) ) return (await asession.execute(statement)).unique().scalars().one_or_none() + + +async def delete_experiment_by_id_from_db( + workspace_id: int, experiment_id: int, asession: AsyncSession +) -> None: + """ + Delete an experiment by ID for a given workspace. + """ + await asession.execute( + delete(NotificationsDB) + .where(NotificationsDB.workspace_id == workspace_id) + .where(NotificationsDB.experiment_id == experiment_id) + ) + + await asession.execute( + delete(ContextDB) + .where(ContextDB.workspace_id == workspace_id) + .where(ContextDB.experiment_id == experiment_id) + ) + + await asession.execute( + delete(ClientDB) + .where(ClientDB.workspace_id == workspace_id) + .where(ClientDB.experiment_id == experiment_id) + ) + + await asession.execute( + delete(ArmDB) + .where(ArmDB.workspace_id == workspace_id) + .where(ArmDB.experiment_id == experiment_id) + ) + await asession.execute( + delete(DrawDB) + .where(DrawDB.workspace_id == workspace_id) + .where(DrawDB.experiment_id == experiment_id) + ) + + await asession.execute( + delete(ExperimentDB) + .where(ExperimentDB.workspace_id == workspace_id) + .where(ExperimentDB.experiment_id == experiment_id) + ) + + await asession.commit() + return None diff --git a/backend/app/experiments/routers.py b/backend/app/experiments/routers.py 
index c63c732..0656c12 100644 --- a/backend/app/experiments/routers.py +++ b/backend/app/experiments/routers.py @@ -16,6 +16,7 @@ ) from .dependencies import experiments_db_to_schema from .models import ( + delete_experiment_by_id_from_db, get_all_experiment_types_from_db, get_all_experiments_from_db, get_experiment_by_id_from_db, @@ -60,6 +61,7 @@ async def create_experiment( notifications = await save_notifications_to_db( experiment_id=experiment_db.experiment_id, user_id=user_db.user_id, + workspace_id=workspace_db.workspace_id, notifications=experiment.notifications, asession=asession, ) @@ -163,3 +165,49 @@ async def get_experiment_by_id( ) return experiment_dict[0] + + +@router.delete("/id/{experiment_id}", response_model=dict[str, str]) +async def delete_experiment_by_id( + experiment_id: int, + user_db: Annotated[UserDB, Depends(get_verified_user)], + asession: AsyncSession = Depends(get_async_session), +) -> dict[str, str]: + """ + Retrieve a specific experiment by ID for the current user's workspace. + """ + try: + workspace_db = await get_user_default_workspace( + asession=asession, user_db=user_db + ) + + if workspace_db is None: + raise HTTPException( + status_code=404, + detail="Workspace not found. 
Please create a workspace first.", + ) + + experiment = await get_experiment_by_id_from_db( + workspace_id=workspace_db.workspace_id, + experiment_id=experiment_id, + asession=asession, + ) + + if not experiment: + raise HTTPException( + status_code=404, + detail="Experiment not found.", + ) + + await delete_experiment_by_id_from_db( + workspace_id=workspace_db.workspace_id, + experiment_id=experiment_id, + asession=asession, + ) + + return {"message": f"Experiment with id {experiment_id} deleted successfully."} + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error: {str(e)}", + ) from e diff --git a/backend/app/experiments/schemas.py b/backend/app/experiments/schemas.py index e978985..8b7ff52 100644 --- a/backend/app/experiments/schemas.py +++ b/backend/app/experiments/schemas.py @@ -411,8 +411,8 @@ class Experiment(ExperimentBase): # Relationships arms: list[Arm] notifications: Notifications - contexts: Optional[list[Context]] = None - clients: Optional[list[Client]] = None + contexts: Optional[list[Context]] = [] + clients: Optional[list[Client]] = [] @model_validator(mode="after") def auto_fail_unit_and_value_set(self) -> Self: From 3e8aa12df2c97edc8f21c9e659bd1afa10770079 Mon Sep 17 00:00:00 2001 From: poornimaramesh Date: Tue, 27 May 2025 21:58:21 +0300 Subject: [PATCH 37/74] add bulk delete router --- backend/app/experiments/routers.py | 49 ++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/backend/app/experiments/routers.py b/backend/app/experiments/routers.py index 0656c12..63c120f 100644 --- a/backend/app/experiments/routers.py +++ b/backend/app/experiments/routers.py @@ -167,6 +167,55 @@ async def get_experiment_by_id( return experiment_dict[0] +@router.delete("/type/{experiment_type}", response_model=dict[str, str]) +async def delete_experiment_by_type( + experiment_type: ExperimentsEnum, + user_db: Annotated[UserDB, Depends(get_verified_user)], + asession: AsyncSession = 
Depends(get_async_session), +) -> dict[str, str]: + """ + Retrieve a specific experiment by ID for the current user's workspace. + """ + try: + workspace_db = await get_user_default_workspace( + asession=asession, user_db=user_db + ) + + if workspace_db is None: + raise HTTPException( + status_code=404, + detail="Workspace not found. Please create a workspace first.", + ) + + experiments = await get_all_experiment_types_from_db( + workspace_id=workspace_db.workspace_id, + experiment_type=experiment_type.value, + asession=asession, + ) + + if len(experiments) == 0: + raise HTTPException( + status_code=404, + detail="No experiments found.", + ) + + for exp in experiments: + await delete_experiment_by_id_from_db( + workspace_id=workspace_db.workspace_id, + experiment_id=exp.experiment_id, + asession=asession, + ) + + return { + "message": f"Experiments of type {experiment_type} deleted successfully." + } + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error: {str(e)}", + ) from e + + @router.delete("/id/{experiment_id}", response_model=dict[str, str]) async def delete_experiment_by_id( experiment_id: int, From 566ca36de0f341c6e46527415b39918431bfd7f7 Mon Sep 17 00:00:00 2001 From: poornimaramesh Date: Wed, 28 May 2025 19:30:56 +0300 Subject: [PATCH 38/74] debugging, WIP sampling utils --- backend/app/experiments/models.py | 7 +- backend/app/experiments/sampling_utils.py | 308 ++++++++++++++++++++++ backend/app/experiments/schemas.py | 6 +- 3 files changed, 319 insertions(+), 2 deletions(-) create mode 100644 backend/app/experiments/sampling_utils.py diff --git a/backend/app/experiments/models.py b/backend/app/experiments/models.py index d8805e6..3eaa8df 100644 --- a/backend/app/experiments/models.py +++ b/backend/app/experiments/models.py @@ -2,6 +2,7 @@ from datetime import datetime, timezone from typing import Optional, Sequence +import numpy as np from sqlalchemy import ( Boolean, DateTime, @@ -506,7 +507,11 @@ async def 
save_experiment_to_db( mu_init=arm.mu_init, sigma_init=arm.sigma_init, mu=[arm.mu_init] * len_contexts, - covariance=[arm.sigma_init] * len_contexts, + covariance=( + (np.identity(len_contexts) * arm.sigma_init**2).tolist() + if arm.sigma_init + else [[None]] + ), alpha_init=arm.alpha_init, beta_init=arm.beta_init, alpha=arm.alpha_init, diff --git a/backend/app/experiments/sampling_utils.py b/backend/app/experiments/sampling_utils.py new file mode 100644 index 0000000..aea2c6d --- /dev/null +++ b/backend/app/experiments/sampling_utils.py @@ -0,0 +1,308 @@ +from typing import Any, Optional + +import numpy as np +from numpy.random import beta +from scipy.optimize import minimize + +from .schemas import ( + ArmPriors, + ContextLinkFunctions, + ExperimentResponse, + ExperimentsEnum, + Outcome, + RewardLikelihood, +) + + +# ------------- Utilities for sampling and updating arms ---------------- +# --- Sampling functions for Thompson Sampling --- +def _sample_beta_binomial(alphas: np.ndarray, betas: np.ndarray) -> int: + """ + Thompson Sampling with Beta-Binomial distribution. + + Parameters + ---------- + alphas : alpha parameter of Beta distribution for each arm + betas : beta parameter of Beta distribution for each arm + """ + samples = beta(alphas, betas) + return int(samples.argmax()) + + +def _sample_normal( + mus: list[np.ndarray], + covariances: list[np.ndarray], + context: np.ndarray, + link_function: ContextLinkFunctions, +) -> int: + """ + Thompson Sampling with normal prior. 
+ + Parameters + ---------- + mus: mean of Normal distribution for each arm + covariances: covariance matrix of Normal distribution for each arm + context: context vector + link_function: link function for the context + """ + samples = np.array( + [ + np.random.multivariate_normal(mean=mu, cov=cov) + for mu, cov in zip(mus, covariances) + ] + ).reshape(-1, len(context)) + probs = link_function(samples @ context) + return int(probs.argmax()) + + +# --- Arm update functions --- +def _update_arm_beta_binomial( + alpha: float, beta: float, reward: Outcome +) -> tuple[float, float]: + """ + Update the alpha and beta parameters of the Beta distribution. + + Parameters + ---------- + alpha : int + The alpha parameter of the Beta distribution. + beta : int + The beta parameter of the Beta distribution. + reward : Outcome + The reward of the arm. + """ + if reward == Outcome.SUCCESS: + + return alpha + 1, beta + else: + return alpha, beta + 1 + + +def _update_arm_normal( + current_mu: np.ndarray, + current_covariance: np.ndarray, + reward: float, + llhood_sigma: float, + context: Optional[np.ndarray] = None, +) -> tuple[float, np.ndarray]: + """ + Update the mean and standard deviation of the Normal distribution. + + Parameters + ---------- + current_mu : The mean of the Normal distribution. + current_covariance : The covariance of the Normal distribution. + reward : The reward of the arm. + llhood_sigma : The standard deviation of the likelihood. + context : The context vector. 
+ """ + # Likelihood covariance matrix inverse + llhood_covariance_inv = np.eye(len(current_mu)) / llhood_sigma**2 + if context: + llhood_covariance_inv *= context.T @ context + + # Prior covariance matrix inverse + prior_covariance_inv = np.linalg.inv(current_covariance) + + # New covariance + new_covariance = np.linalg.inv(prior_covariance_inv + llhood_covariance_inv) + + # New mean + llhood_term = reward / llhood_sigma**2 + if context: + llhood_term = context.T * llhood_term + new_mu = new_covariance @ ((prior_covariance_inv @ current_mu) + llhood_term) + + return new_mu, new_covariance + + +def _update_arm_laplace( + current_mu: np.ndarray, + current_covariance: np.ndarray, + reward: np.ndarray, + context: np.ndarray, + link_function: ContextLinkFunctions, + reward_likelihood: RewardLikelihood, + prior_type: ArmPriors, +) -> tuple[np.ndarray, np.ndarray]: + """ + Update the mean and covariance using the Laplace approximation. + + Parameters + ---------- + current_mu : The mean of the normal distribution. + current_covariance : The covariance matrix of the normal distribution. + reward : The list of rewards for the arm. + context : The list of contexts for the arm. + link_function : The link function for parameters to rewards. + reward_likelihood : The likelihood function of the reward. + prior_type : The prior type of the arm. + """ + + def objective(theta: np.ndarray) -> float: + """ + Objective function for the Laplace approximation. + + Parameters + ---------- + theta : The parameters of the arm. 
+ """ + # Log prior + log_prior = prior_type(theta, mu=current_mu, covariance=current_covariance) + + # Log likelihood + log_likelihood = reward_likelihood(reward, link_function(context @ theta)) + + return -log_prior - log_likelihood + + result = minimize( + objective, x0=np.zeros_like(current_mu), method="L-BFGS-B", hess="2-point" + ) + new_mu = result.x + covariance = result.hess_inv.todense() # type: ignore + + new_covariance = 0.5 * (covariance + covariance.T) + return new_mu, new_covariance.astype(np.float64) + + +# ------------- Import functions ---------------- +# --- Choose arm function --- +def choose_arm(experiment: ExperimentResponse, context: Optional[list]) -> int: + """ + Choose arm based on posterior using Thompson Sampling. + + Parameters + ---------- + experiment: The experiment data containing priors and rewards for each arm. + context: Optional context vector for the experiment. + """ + # Choose arms with equal probability for Bayesian A/B tests + if experiment.exp_type == ExperimentsEnum.BAYESAB: + index = np.random.choice(len(experiment.arms), size=1) + return int(index[0]) + else: + if experiment.prior_type == ArmPriors.BETA: + if experiment.reward_type != RewardLikelihood.BERNOULLI: + raise ValueError("Beta prior is only supported for Bernoulli rewards.") + alphas = np.array([arm.alpha for arm in experiment.arms]) + betas = np.array([arm.beta for arm in experiment.arms]) + + return _sample_beta_binomial(alphas=alphas, betas=betas) + + elif experiment.prior_type == ArmPriors.NORMAL: + mus = [np.array(arm.mu) for arm in experiment.arms] + covariances = [np.array(arm.covariance) for arm in experiment.arms] + if not context: + context = np.ones_like(mus[0]) + + return _sample_normal( + mus=mus, + covariances=covariances, + context=context, + link_function=( + ContextLinkFunctions.NONE + if experiment.reward_type == RewardLikelihood.NORMAL + else ContextLinkFunctions.LOGISTIC + ), + ) + + +# --- Update arm parameters --- +def update_arm( + 
experiment: ExperimentResponse, + rewards: list[float], + arm_to_update: Optional[int] = None, + context: Optional[np.ndarray] = None, + treatments: Optional[list[int]] = None, +) -> Any: + """ + Update the arm parameters based on the experiment type and reward. + + Parameters + ---------- + experiment: The experiment data containing arms, prior type and reward + type information. + rewards: The rewards received from the arm. + context: The context vector for the arm. + treatments: The treatments applied to the arm, for a Bayesian A/B test. + """ + + # NB: For Bayesian AB tests, we assume that the update runs + # AFTER all rewards have been observed. + # We hijack the Laplace approximation function to update the + # model parameters as follows: + # 1. current_mu -> [treatment_mu, control_mu, bias_mu = 0] + # 2. current_covariance -> [treatment_sigma, control_sigma, bias_sigma = 1] + # 3. context -> [is_treatment_arm, is_control_arm, 1] + if experiment.exp_type == ExperimentsEnum.BAYESAB: + + assert treatments, "Treatments must be provided for Bayesian A/B tests." 
+ + mus = np.array([arm.mu for arm in experiment.arms] + [0.0]) + covariances = np.diag( + [np.array(arm.covariance).ravel()[0] for arm in experiment.arms] + [1.0] + ) + + context = np.zeros((len(rewards), 3)) + context[:, 0] = np.array(treatments) + context[:, 1] = 1.0 - np.array(treatments) + context[:, 2] = 1.0 + + new_mus, new_covariances = _update_arm_laplace( + current_mu=mus, + current_covariance=covariances, + reward=np.array(rewards), + context=context, + link_function=( + ContextLinkFunctions.NONE + if experiment.reward_type == RewardLikelihood.NORMAL + else ContextLinkFunctions.LOGISTIC + ), + reward_likelihood=experiment.reward_type, + prior_type=experiment.prior_type, + ) + + treatment_mu, control_mu, _ = new_mus + treatment_sigma, control_sigma, _ = np.diag(new_covariances) + return [treatment_mu, control_mu], [[treatment_sigma]], [[control_sigma]] + else: + # Update for MABs and CMABs + assert ( + arm_to_update + ), f"arm_to_update must be provided for {experiment.exp_type} experiments." + + arm = experiment.arms[arm_to_update] + assert arm.alpha and arm.beta, "Arm must have alpha and beta parameters." 
+ + # Beta-binomial priors + if experiment.prior_type == ArmPriors.BETA: + return _update_arm_beta_binomial( + alpha=arm.alpha, beta=arm.beta, reward=Outcome(rewards[0]) + ) + + # Normal priors + elif experiment.prior_type == ArmPriors.NORMAL: + if context is None: + context = np.ones_like(arm.mu) + # Normal likelihood + if experiment.reward_type == RewardLikelihood.NORMAL: + return _update_arm_normal( + current_mu=np.array(arm.mu), + current_covariance=np.array(arm.covariance), + reward=rewards[0], + llhood_sigma=1.0, # TODO: Assuming a fixed likelihood sigma + context=context, + ) + # TODO: only supports Bernoulli likelihood + else: + return _update_arm_laplace( + current_mu=np.array(arm.mu), + current_covariance=np.array(arm.covariance), + reward=np.array(rewards), + context=context, + link_function=ContextLinkFunctions.LOGISTIC, + reward_likelihood=experiment.reward_type, + prior_type=experiment.prior_type, + ) + else: + raise ValueError("Unsupported prior type for arm update.") diff --git a/backend/app/experiments/schemas.py b/backend/app/experiments/schemas.py index 8b7ff52..207309f 100644 --- a/backend/app/experiments/schemas.py +++ b/backend/app/experiments/schemas.py @@ -241,7 +241,7 @@ class ArmResponse(Arm): alpha: Optional[Union[float, None]] beta: Optional[Union[float, None]] mu: Optional[List[Union[float, None]]] - covariance: Optional[List[Union[float, None]]] + covariance: Optional[List[List[Union[float, None]]]] draws: Optional[List[Union[float, None]]] model_config = ConfigDict( from_attributes=True, @@ -482,6 +482,10 @@ def check_prior_reward_type_combo(self) -> Self: raise ValueError( "Beta prior can only be used with binary-valued rewards." ) + if self.exp_type != ExperimentsEnum.MAB: + raise ValueError( + f"Experiments of type {self.exp_type} can only use Gaussian priors." 
+ ) return self From 483d481895a9dfb1d2ee155aea8813bde2d7ca7b Mon Sep 17 00:00:00 2001 From: poornimaramesh Date: Wed, 28 May 2025 21:38:29 +0300 Subject: [PATCH 39/74] fix linting --- backend/app/experiments/models.py | 8 ++++---- backend/app/experiments/sampling_utils.py | 14 ++++++++------ 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/backend/app/experiments/models.py b/backend/app/experiments/models.py index 3eaa8df..5f588d7 100644 --- a/backend/app/experiments/models.py +++ b/backend/app/experiments/models.py @@ -111,7 +111,7 @@ def has_contexts(self) -> bool: return self.exp_type == "cmab" @property - def context_list(self) -> list["ContextDB"]: + def context_list(self) -> list["ContextDB"] | list[None]: """Get contexts, returning empty list if not applicable.""" return self.contexts if self.has_contexts else [] @@ -139,9 +139,9 @@ def to_dict(self) -> dict: "arms": [arm.to_dict() for arm in self.arms], "draws": [draw.to_dict() for draw in self.draws], "contexts": ( - [context.to_dict() for context in self.context_list] - if self.has_contexts - else None + [context.to_dict() for context in self.context_list if context] + if len(self.context_list) > 0 + else [] ), } diff --git a/backend/app/experiments/sampling_utils.py b/backend/app/experiments/sampling_utils.py index aea2c6d..32c454f 100644 --- a/backend/app/experiments/sampling_utils.py +++ b/backend/app/experiments/sampling_utils.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any, Optional, Union import numpy as np from numpy.random import beta @@ -108,12 +108,12 @@ def _update_arm_normal( new_covariance = np.linalg.inv(prior_covariance_inv + llhood_covariance_inv) # New mean - llhood_term = reward / llhood_sigma**2 + llhood_term: Union[np.ndarray, float] = reward / llhood_sigma**2 if context: llhood_term = context.T * llhood_term new_mu = new_covariance @ ((prior_covariance_inv @ current_mu) + llhood_term) - return new_mu, new_covariance + return 
new_mu.tolist(), new_covariance.tolist() def _update_arm_laplace( @@ -162,12 +162,14 @@ def objective(theta: np.ndarray) -> float: covariance = result.hess_inv.todense() # type: ignore new_covariance = 0.5 * (covariance + covariance.T) - return new_mu, new_covariance.astype(np.float64) + return new_mu.tolist(), new_covariance.tolist() # ------------- Import functions ---------------- # --- Choose arm function --- -def choose_arm(experiment: ExperimentResponse, context: Optional[list]) -> int: +def choose_arm( + experiment: ExperimentResponse, context: Optional[Union[list, np.ndarray, None]] +) -> int: """ Choose arm based on posterior using Thompson Sampling. @@ -212,7 +214,7 @@ def update_arm( experiment: ExperimentResponse, rewards: list[float], arm_to_update: Optional[int] = None, - context: Optional[np.ndarray] = None, + context: Optional[Union[list, np.ndarray, None]] = None, treatments: Optional[list[int]] = None, ) -> Any: """ From d9dacdad759e6afda57b83dfed22484f64b2217b Mon Sep 17 00:00:00 2001 From: poornimaramesh Date: Thu, 29 May 2025 09:58:03 +0300 Subject: [PATCH 40/74] WIP: draw arm + update arm routers --- backend/app/experiments/dependencies.py | 6 +- backend/app/experiments/models.py | 50 +++++++++++++++++ backend/app/experiments/routers.py | 67 +++++++++++++++++++---- backend/app/experiments/sampling_utils.py | 6 +- backend/app/experiments/schemas.py | 16 ++++++ 5 files changed, 129 insertions(+), 16 deletions(-) diff --git a/backend/app/experiments/dependencies.py b/backend/app/experiments/dependencies.py index b305477..796e015 100644 --- a/backend/app/experiments/dependencies.py +++ b/backend/app/experiments/dependencies.py @@ -1,13 +1,13 @@ from sqlalchemy.ext.asyncio import AsyncSession from .models import ExperimentDB, get_notifications_from_db -from .schemas import ExperimentResponse, NotificationsResponse +from .schemas import ExperimentSample, NotificationsResponse async def experiments_db_to_schema( experiments_db: list[ExperimentDB], 
asession: AsyncSession, -) -> list[ExperimentResponse]: +) -> list[ExperimentSample]: """ Convert a list of ExperimentDB objects to a list of ExperimentResponse schemas. """ @@ -24,7 +24,7 @@ async def experiments_db_to_schema( ) ] all_experiments.append( - ExperimentResponse.model_validate( + ExperimentSample.model_validate( { **exp_dict, "notifications": [ diff --git a/backend/app/experiments/models.py b/backend/app/experiments/models.py index 5f588d7..2b8725c 100644 --- a/backend/app/experiments/models.py +++ b/backend/app/experiments/models.py @@ -648,3 +648,53 @@ async def delete_experiment_by_id_from_db( await asession.commit() return None + + +# Draw functions +async def get_draw_by_id(draw_id: str, asession: AsyncSession) -> DrawDB | None: + """ + Get a draw by its ID, which should be unique across the system. + """ + statement = select(DrawDB).where(DrawDB.draw_id == draw_id) + result = await asession.execute(statement) + + return result.unique().scalar_one_or_none() + + +async def save_draw_to_db( + draw_id: str, + arm_id: int, + experiment_id: int, + user_id: int | None, + workspace_id: int, + client_id: str, + context: list[float] | None, + asession: AsyncSession, +) -> DrawDB: + """ + Save a draw to the database. + """ + if not user_id: + experiment = await get_experiment_by_id_from_db( + experiment_id=experiment_id, workspace_id=workspace_id, asession=asession + ) + if not experiment: + raise ValueError( + f"Experiment with id {experiment_id} not found for the given ID." 
+ ) + experiment_id = experiment.experiment_id + draw = DrawDB( + draw_id=draw_id, + arm_id=arm_id, + experiment_id=experiment_id, + user_id=user_id, + workspace_id=workspace_id, + client_id=client_id, + draw_datetime_utc=datetime.now(timezone.utc), + context_val=context, + ) + asession.add(draw) + await asession.commit() + await asession.refresh(draw) + + return draw diff --git a/backend/app/experiments/routers.py b/backend/app/experiments/routers.py index 63c120f..c4279e2 100644 --- a/backend/app/experiments/routers.py +++ b/backend/app/experiments/routers.py @@ -25,7 +25,7 @@ ) from .schemas import ( Experiment, - ExperimentResponse, + ExperimentSample, ExperimentsEnum, ) @@ -35,12 +35,12 @@ # --- POST experiments routers --- -@router.post("/", response_model=ExperimentResponse) +@router.post("/", response_model=ExperimentSample) async def create_experiment( experiment: Experiment, user_db: Annotated[UserDB, Depends(require_admin_role)], asession: AsyncSession = Depends(get_async_session), -) -> ExperimentResponse: +) -> ExperimentSample: """ Create a new experiment in the current user's workspace. """ @@ -68,15 +68,15 @@ async def create_experiment( experiment_dict = experiment_db.to_dict() experiment_dict["notifications"] = [n.to_dict() for n in notifications] - return ExperimentResponse.model_validate(experiment_dict) + return ExperimentSample.model_validate(experiment_dict) # -- GET experiment routers --- -@router.get("/", response_model=list[ExperimentResponse]) +@router.get("/", response_model=list[ExperimentSample]) async def get_all_experiments( user_db: Annotated[UserDB, Depends(get_verified_user)], asession: AsyncSession = Depends(get_async_session), -) -> list[ExperimentResponse]: +) -> list[ExperimentSample]: """ Retrieve all experiments for the current user's workspace. 
""" @@ -100,12 +100,12 @@ async def get_all_experiments( return all_experiments -@router.get("/type/{experiment_type}", response_model=list[ExperimentResponse]) +@router.get("/type/{experiment_type}", response_model=list[ExperimentSample]) async def get_all_experiments_by_type( experiment_type: ExperimentsEnum, user_db: Annotated[UserDB, Depends(get_verified_user)], asession: AsyncSession = Depends(get_async_session), -) -> list[ExperimentResponse]: +) -> list[ExperimentSample]: """ Retrieve all experiments for the current user's workspace. """ @@ -130,12 +130,12 @@ async def get_all_experiments_by_type( return all_experiments -@router.get("/id/{experiment_id}", response_model=ExperimentResponse) +@router.get("/id/{experiment_id}", response_model=ExperimentSample) async def get_experiment_by_id( experiment_id: int, user_db: Annotated[UserDB, Depends(get_verified_user)], asession: AsyncSession = Depends(get_async_session), -) -> ExperimentResponse: +) -> ExperimentSample: """ Retrieve a specific experiment by ID for the current user's workspace. """ @@ -167,6 +167,7 @@ async def get_experiment_by_id( return experiment_dict[0] +# -- DELETE experiment routers --- @router.delete("/type/{experiment_type}", response_model=dict[str, str]) async def delete_experiment_by_type( experiment_type: ExperimentsEnum, @@ -260,3 +261,49 @@ async def delete_experiment_by_id( status_code=500, detail=f"Error: {str(e)}", ) from e + + +# --- Draw and update arms --- +# @router.get("/{experiment_id}/draw", response_model=DrawResponse) +# async def draw_arm( +# experiment_id: int, +# context: Optional[list] = None, +# draw_id: Optional[str] = None, +# workspace_db: WorkspaceDB = Depends(authenticate_workspace_key), +# asession: AsyncSession = Depends(get_async_session), +# ) -> DrawResponse: +# """ +# Draw an arm from the specified experiment. 
+# """ +# workspace_id = workspace_db.workspace_id + +# experiment = await get_experiment_by_id_from_db( +# workspace_id=workspace_id, experiment_id=experiment_id, asession=asession +# ) +# if experiment is None: +# raise HTTPException( +# status_code=404, detail=f"Experiment with id {experiment_id} not found" +# ) + +# if (experiment.exp_type == ExperimentsEnum.CMAB.value) and (not context): +# raise HTTPException( +# status_code=400, detail="Context is required for CMAB experiments." +# ) + +# # Check for existing draws +# if draw_id is None: +# draw_id = str(uuid4()) + +# existing_draw = await get_draw_by_id(draw_id=draw_id, asession=asession) +# if existing_draw: +# raise HTTPException( +# status_code=400, detail=f"Draw with id {draw_id} already exists." +# ) + +# # Perform the draw +# experiment_data = ExperimentSample.model_validate(experiment) +# chosen_arm = choose_arm(experiment=experiment_data, context=context) +# chosen_arm_id = experiment.arms[chosen_arm].arm_id + +# try: +# draw = await sa diff --git a/backend/app/experiments/sampling_utils.py b/backend/app/experiments/sampling_utils.py index 32c454f..278d0c9 100644 --- a/backend/app/experiments/sampling_utils.py +++ b/backend/app/experiments/sampling_utils.py @@ -7,7 +7,7 @@ from .schemas import ( ArmPriors, ContextLinkFunctions, - ExperimentResponse, + ExperimentSample, ExperimentsEnum, Outcome, RewardLikelihood, @@ -168,7 +168,7 @@ def objective(theta: np.ndarray) -> float: # ------------- Import functions ---------------- # --- Choose arm function --- def choose_arm( - experiment: ExperimentResponse, context: Optional[Union[list, np.ndarray, None]] + experiment: ExperimentSample, context: Optional[Union[list, np.ndarray, None]] ) -> int: """ Choose arm based on posterior using Thompson Sampling. 
@@ -211,7 +211,7 @@ def choose_arm( # --- Update arm parameters --- def update_arm( - experiment: ExperimentResponse, + experiment: ExperimentSample, rewards: list[float], arm_to_update: Optional[int] = None, context: Optional[Union[list, np.ndarray, None]] = None, diff --git a/backend/app/experiments/schemas.py b/backend/app/experiments/schemas.py index 207309f..26a6c41 100644 --- a/backend/app/experiments/schemas.py +++ b/backend/app/experiments/schemas.py @@ -520,3 +520,19 @@ class ExperimentResponse(ExperimentBase): clients: Optional[list[Client]] = None model_config = ConfigDict(from_attributes=True) + + +class ExperimentSample(ExperimentBase): + """ + Pydantic model for experiments for drawing and updating arms. + """ + + experiment_id: int + n_trials: int + last_trial_datetime_utc: Optional[str] = None + + arms: list[ArmResponse] + contexts: Optional[list[ContextResponse]] = None + clients: Optional[list[Client]] = None + + model_config = ConfigDict(from_attributes=True) From 4ab45c59fb52fa44281ac3566d3da32b9018cf16 Mon Sep 17 00:00:00 2001 From: poornimaramesh Date: Thu, 29 May 2025 09:58:55 +0300 Subject: [PATCH 41/74] fresh start migrations --- .../275ff74c0866_add_client_id_to_draws_db.py | 30 - ..._add_tables_for_bayesian_ab_experiments.py | 66 --- .../versions/2d3946caceff_new_start.py | 559 ++++++++++++++++++ ...e1aa8ae_update_tables_with_workspace_id.py | 131 ++++ ...added_first_name_and_last_name_to_users.py | 36 -- .../versions/9f7482ba882f_workspace_model.py | 123 ---- .../ecddd830b464_remove_user_api_key.py | 70 --- .../versions/faf4228e13a3_clean_start.py | 257 -------- ...d_added_sticky_assignments_and_autofail.py | 59 -- 9 files changed, 690 insertions(+), 641 deletions(-) delete mode 100644 backend/migrations/versions/275ff74c0866_add_client_id_to_draws_db.py delete mode 100644 backend/migrations/versions/28adf347e68d_add_tables_for_bayesian_ab_experiments.py create mode 100644 backend/migrations/versions/2d3946caceff_new_start.py create mode 
100644 backend/migrations/versions/57173e1aa8ae_update_tables_with_workspace_id.py delete mode 100644 backend/migrations/versions/5c15463fda65_added_first_name_and_last_name_to_users.py delete mode 100644 backend/migrations/versions/9f7482ba882f_workspace_model.py delete mode 100644 backend/migrations/versions/ecddd830b464_remove_user_api_key.py delete mode 100644 backend/migrations/versions/faf4228e13a3_clean_start.py delete mode 100644 backend/migrations/versions/feb042798cad_added_sticky_assignments_and_autofail.py diff --git a/backend/migrations/versions/275ff74c0866_add_client_id_to_draws_db.py b/backend/migrations/versions/275ff74c0866_add_client_id_to_draws_db.py deleted file mode 100644 index 02d31e3..0000000 --- a/backend/migrations/versions/275ff74c0866_add_client_id_to_draws_db.py +++ /dev/null @@ -1,30 +0,0 @@ -"""add client id to draws db - -Revision ID: 275ff74c0866 -Revises: 5c15463fda65 -Create Date: 2025-04-28 20:01:35.705717 - -""" - -from typing import Sequence, Union - -import sqlalchemy as sa -from alembic import op - -# revision identifiers, used by Alembic. -revision: str = "275ff74c0866" -down_revision: Union[str, None] = "5c15463fda65" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.add_column("draws_base", sa.Column("client_id", sa.String(), nullable=True)) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! 
### - op.drop_column("draws_base", "client_id") - # ### end Alembic commands ### diff --git a/backend/migrations/versions/28adf347e68d_add_tables_for_bayesian_ab_experiments.py b/backend/migrations/versions/28adf347e68d_add_tables_for_bayesian_ab_experiments.py deleted file mode 100644 index 94ecd57..0000000 --- a/backend/migrations/versions/28adf347e68d_add_tables_for_bayesian_ab_experiments.py +++ /dev/null @@ -1,66 +0,0 @@ -"""add tables for Bayesian AB experiments - -Revision ID: 28adf347e68d -Revises: feb042798cad -Create Date: 2025-04-27 11:23:26.823140 - -""" - -from typing import Sequence, Union - -import sqlalchemy as sa -from alembic import op - -# revision identifiers, used by Alembic. -revision: str = "28adf347e68d" -down_revision: Union[str, None] = "feb042798cad" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.create_table( - "bayes_ab_experiments", - sa.Column("experiment_id", sa.Integer(), nullable=False), - sa.ForeignKeyConstraint( - ["experiment_id"], ["experiments_base.experiment_id"], ondelete="CASCADE" - ), - sa.PrimaryKeyConstraint("experiment_id"), - ) - op.create_table( - "bayes_ab_arms", - sa.Column("arm_id", sa.Integer(), nullable=False), - sa.Column("mu_init", sa.Float(), nullable=False), - sa.Column("sigma_init", sa.Float(), nullable=False), - sa.Column("mu", sa.Float(), nullable=False), - sa.Column("sigma", sa.Float(), nullable=False), - sa.Column("is_treatment_arm", sa.Boolean(), nullable=False), - sa.ForeignKeyConstraint(["arm_id"], ["arms_base.arm_id"], ondelete="CASCADE"), - sa.PrimaryKeyConstraint("arm_id"), - ) - op.create_table( - "bayes_ab_draws", - sa.Column("draw_id", sa.String(), nullable=False), - sa.ForeignKeyConstraint( - ["draw_id"], ["draws_base.draw_id"], ondelete="CASCADE" - ), - sa.PrimaryKeyConstraint("draw_id"), - ) - op.add_column("mab_arms", 
sa.Column("alpha_init", sa.Float(), nullable=True)) - op.add_column("mab_arms", sa.Column("beta_init", sa.Float(), nullable=True)) - op.add_column("mab_arms", sa.Column("mu_init", sa.Float(), nullable=True)) - op.add_column("mab_arms", sa.Column("sigma_init", sa.Float(), nullable=True)) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.drop_column("mab_arms", "sigma_init") - op.drop_column("mab_arms", "mu_init") - op.drop_column("mab_arms", "beta_init") - op.drop_column("mab_arms", "alpha_init") - op.drop_table("bayes_ab_draws") - op.drop_table("bayes_ab_arms") - op.drop_table("bayes_ab_experiments") - # ### end Alembic commands ### diff --git a/backend/migrations/versions/2d3946caceff_new_start.py b/backend/migrations/versions/2d3946caceff_new_start.py new file mode 100644 index 0000000..820aa7c --- /dev/null +++ b/backend/migrations/versions/2d3946caceff_new_start.py @@ -0,0 +1,559 @@ +"""new start + +Revision ID: 2d3946caceff +Revises: +Create Date: 2025-05-27 18:39:15.282285 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = "2d3946caceff" +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table( + "users", + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("username", sa.String(), nullable=False), + sa.Column("first_name", sa.String(), nullable=False), + sa.Column("last_name", sa.String(), nullable=False), + sa.Column("hashed_password", sa.String(length=96), nullable=False), + sa.Column("created_datetime_utc", sa.DateTime(timezone=True), nullable=False), + sa.Column("updated_datetime_utc", sa.DateTime(timezone=True), nullable=False), + sa.Column("access_level", sa.String(), nullable=False), + sa.Column("is_active", sa.Boolean(), nullable=False), + sa.Column("is_verified", sa.Boolean(), nullable=False), + sa.PrimaryKeyConstraint("user_id"), + sa.UniqueConstraint("username"), + ) + op.create_table( + "messages", + sa.Column("message_id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("text", sa.String(), nullable=False), + sa.Column("title", sa.String(), nullable=False), + sa.Column("is_unread", sa.Boolean(), nullable=False), + sa.Column("created_datetime_utc", sa.DateTime(timezone=True), nullable=False), + sa.Column("message_type", sa.String(length=50), nullable=False), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.user_id"], + ), + sa.PrimaryKeyConstraint("message_id"), + ) + op.create_table( + "workspace", + sa.Column("api_daily_quota", sa.Integer(), nullable=True), + sa.Column("api_key_first_characters", sa.String(length=5), nullable=True), + sa.Column( + "api_key_updated_datetime_utc", sa.DateTime(timezone=True), nullable=True + ), + sa.Column("api_key_rotated_by_user_id", sa.Integer(), nullable=True), + sa.Column("content_quota", sa.Integer(), nullable=True), + sa.Column("created_datetime_utc", sa.DateTime(timezone=True), nullable=False), + sa.Column("hashed_api_key", sa.String(length=96), nullable=True), + sa.Column("updated_datetime_utc", sa.DateTime(timezone=True), nullable=False), + sa.Column("workspace_id", sa.Integer(), nullable=False), + 
sa.Column("workspace_name", sa.String(), nullable=False), + sa.Column("is_default", sa.Boolean(), nullable=False), + sa.ForeignKeyConstraint( + ["api_key_rotated_by_user_id"], + ["users.user_id"], + ), + sa.PrimaryKeyConstraint("workspace_id"), + sa.UniqueConstraint("hashed_api_key"), + sa.UniqueConstraint("workspace_name"), + ) + op.create_table( + "api_key_rotation_history", + sa.Column("rotation_id", sa.Integer(), nullable=False), + sa.Column("workspace_id", sa.Integer(), nullable=False), + sa.Column("rotated_by_user_id", sa.Integer(), nullable=False), + sa.Column("key_first_characters", sa.String(length=5), nullable=False), + sa.Column("rotation_datetime_utc", sa.DateTime(timezone=True), nullable=False), + sa.ForeignKeyConstraint( + ["rotated_by_user_id"], + ["users.user_id"], + ), + sa.ForeignKeyConstraint( + ["workspace_id"], ["workspace.workspace_id"], ondelete="CASCADE" + ), + sa.PrimaryKeyConstraint("rotation_id"), + ) + op.create_table( + "experiments", + sa.Column("experiment_id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("workspace_id", sa.Integer(), nullable=False), + sa.Column("name", sa.String(length=150), nullable=False), + sa.Column("description", sa.String(length=500), nullable=False), + sa.Column("is_active", sa.Boolean(), nullable=False), + sa.Column("sticky_assignment", sa.Boolean(), nullable=False), + sa.Column("auto_fail", sa.Boolean(), nullable=False), + sa.Column("auto_fail_value", sa.Integer(), nullable=True), + sa.Column( + "auto_fail_unit", + sa.Enum("DAYS", "HOURS", name="autofailunittype"), + nullable=True, + ), + sa.Column("exp_type", sa.String(length=50), nullable=False), + sa.Column("prior_type", sa.String(length=50), nullable=False), + sa.Column("reward_type", sa.String(length=50), nullable=False), + sa.Column("created_datetime_utc", sa.DateTime(timezone=True), nullable=False), + sa.Column("n_trials", sa.Integer(), nullable=False), + sa.Column("last_trial_datetime_utc", 
sa.DateTime(timezone=True), nullable=True), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.user_id"], + ), + sa.ForeignKeyConstraint( + ["workspace_id"], + ["workspace.workspace_id"], + ), + sa.PrimaryKeyConstraint("experiment_id"), + ) + op.create_table( + "experiments_base", + sa.Column("experiment_id", sa.Integer(), nullable=False), + sa.Column("name", sa.String(length=150), nullable=False), + sa.Column("description", sa.String(length=500), nullable=False), + sa.Column("sticky_assignment", sa.Boolean(), nullable=False), + sa.Column("auto_fail", sa.Boolean(), nullable=False), + sa.Column("auto_fail_value", sa.Integer(), nullable=True), + sa.Column( + "auto_fail_unit", + sa.Enum("DAYS", "HOURS", name="autofailunittype"), + nullable=True, + ), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("workspace_id", sa.Integer(), nullable=False), + sa.Column("is_active", sa.Boolean(), nullable=False), + sa.Column("exp_type", sa.String(length=50), nullable=False), + sa.Column("prior_type", sa.String(length=50), nullable=False), + sa.Column("reward_type", sa.String(length=50), nullable=False), + sa.Column("created_datetime_utc", sa.DateTime(timezone=True), nullable=False), + sa.Column("n_trials", sa.Integer(), nullable=False), + sa.Column("last_trial_datetime_utc", sa.DateTime(timezone=True), nullable=True), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.user_id"], + ), + sa.ForeignKeyConstraint( + ["workspace_id"], + ["workspace.workspace_id"], + ), + sa.PrimaryKeyConstraint("experiment_id"), + ) + op.create_table( + "pending_invitations", + sa.Column("invitation_id", sa.Integer(), nullable=False), + sa.Column("email", sa.String(), nullable=False), + sa.Column("workspace_id", sa.Integer(), nullable=False), + sa.Column( + "role", + sa.Enum("ADMIN", "READ_ONLY", name="userroles", native_enum=False), + nullable=False, + ), + sa.Column("inviter_id", sa.Integer(), nullable=False), + sa.Column("created_datetime_utc", sa.DateTime(timezone=True), 
nullable=False), + sa.ForeignKeyConstraint( + ["inviter_id"], + ["users.user_id"], + ), + sa.ForeignKeyConstraint( + ["workspace_id"], ["workspace.workspace_id"], ondelete="CASCADE" + ), + sa.PrimaryKeyConstraint("invitation_id"), + ) + op.create_table( + "user_workspace", + sa.Column("created_datetime_utc", sa.DateTime(timezone=True), nullable=False), + sa.Column( + "default_workspace", + sa.Boolean(), + server_default=sa.text("false"), + nullable=False, + ), + sa.Column("updated_datetime_utc", sa.DateTime(timezone=True), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column( + "user_role", + sa.Enum("ADMIN", "READ_ONLY", name="userroles", native_enum=False), + nullable=False, + ), + sa.Column("workspace_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(["user_id"], ["users.user_id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint( + ["workspace_id"], ["workspace.workspace_id"], ondelete="CASCADE" + ), + sa.PrimaryKeyConstraint("user_id", "workspace_id"), + ) + op.create_table( + "arms", + sa.Column("arm_id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("workspace_id", sa.Integer(), nullable=False), + sa.Column("experiment_id", sa.Integer(), nullable=False), + sa.Column("name", sa.String(length=150), nullable=False), + sa.Column("description", sa.String(length=500), nullable=False), + sa.Column("n_outcomes", sa.Integer(), nullable=False), + sa.Column("mu_init", sa.Float(), nullable=True), + sa.Column("sigma_init", sa.Float(), nullable=True), + sa.Column("mu", postgresql.ARRAY(sa.Float()), nullable=True), + sa.Column("covariance", postgresql.ARRAY(sa.Float()), nullable=True), + sa.Column("alpha_init", sa.Float(), nullable=True), + sa.Column("beta_init", sa.Float(), nullable=True), + sa.Column("alpha", sa.Float(), nullable=True), + sa.Column("beta", sa.Float(), nullable=True), + sa.ForeignKeyConstraint( + ["experiment_id"], + ["experiments.experiment_id"], + ), + 
sa.ForeignKeyConstraint( + ["user_id"], + ["users.user_id"], + ), + sa.ForeignKeyConstraint( + ["workspace_id"], + ["workspace.workspace_id"], + ), + sa.PrimaryKeyConstraint("arm_id"), + ) + op.create_table( + "arms_base", + sa.Column("arm_id", sa.Integer(), nullable=False), + sa.Column("experiment_id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("name", sa.String(length=150), nullable=False), + sa.Column("description", sa.String(length=500), nullable=False), + sa.Column("arm_type", sa.String(length=50), nullable=False), + sa.Column("n_outcomes", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["experiment_id"], + ["experiments_base.experiment_id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.user_id"], + ), + sa.PrimaryKeyConstraint("arm_id"), + ) + op.create_table( + "bayes_ab_experiments", + sa.Column("experiment_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["experiment_id"], ["experiments_base.experiment_id"], ondelete="CASCADE" + ), + sa.PrimaryKeyConstraint("experiment_id"), + ) + op.create_table( + "clients", + sa.Column("client_id", sa.String(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("experiment_id", sa.Integer(), nullable=False), + sa.Column("workspace_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["experiment_id"], + ["experiments.experiment_id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.user_id"], + ), + sa.ForeignKeyConstraint( + ["workspace_id"], + ["workspace.workspace_id"], + ), + sa.PrimaryKeyConstraint("client_id"), + ) + op.create_table( + "context", + sa.Column("context_id", sa.Integer(), nullable=False), + sa.Column("experiment_id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("name", sa.String(length=150), nullable=False), + sa.Column("description", sa.String(length=500), nullable=True), + sa.Column("value_type", 
sa.String(length=50), nullable=False), + sa.ForeignKeyConstraint( + ["experiment_id"], + ["experiments.experiment_id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.user_id"], + ), + sa.PrimaryKeyConstraint("context_id"), + ) + op.create_table( + "contextual_mabs", + sa.Column("experiment_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["experiment_id"], ["experiments_base.experiment_id"], ondelete="CASCADE" + ), + sa.PrimaryKeyConstraint("experiment_id"), + ) + op.create_table( + "event_messages", + sa.Column("message_id", sa.Integer(), nullable=False), + sa.Column("experiment_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["experiment_id"], + ["experiments_base.experiment_id"], + ), + sa.ForeignKeyConstraint( + ["message_id"], ["messages.message_id"], ondelete="CASCADE" + ), + sa.PrimaryKeyConstraint("message_id"), + ) + op.create_table( + "mabs", + sa.Column("experiment_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["experiment_id"], ["experiments_base.experiment_id"], ondelete="CASCADE" + ), + sa.PrimaryKeyConstraint("experiment_id"), + ) + op.create_table( + "notifications", + sa.Column("notification_id", sa.Integer(), nullable=False), + sa.Column("experiment_id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("workspace_id", sa.Integer(), nullable=False), + sa.Column( + "notification_type", + sa.Enum( + "DAYS_ELAPSED", + "TRIALS_COMPLETED", + "PERCENTAGE_BETTER", + name="eventtype", + ), + nullable=False, + ), + sa.Column("notification_value", sa.Integer(), nullable=False), + sa.Column("is_active", sa.Boolean(), nullable=False), + sa.ForeignKeyConstraint( + ["experiment_id"], + ["experiments.experiment_id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.user_id"], + ), + sa.ForeignKeyConstraint( + ["workspace_id"], + ["workspace.workspace_id"], + ), + sa.PrimaryKeyConstraint("notification_id"), + ) + op.create_table( + 
"notifications_db", + sa.Column("notification_id", sa.Integer(), nullable=False), + sa.Column("experiment_id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column( + "notification_type", + sa.Enum( + "DAYS_ELAPSED", + "TRIALS_COMPLETED", + "PERCENTAGE_BETTER", + name="eventtype", + ), + nullable=False, + ), + sa.Column("notification_value", sa.Integer(), nullable=False), + sa.Column("is_active", sa.Boolean(), nullable=False), + sa.ForeignKeyConstraint( + ["experiment_id"], + ["experiments_base.experiment_id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.user_id"], + ), + sa.PrimaryKeyConstraint("notification_id"), + ) + op.create_table( + "bayes_ab_arms", + sa.Column("arm_id", sa.Integer(), nullable=False), + sa.Column("mu_init", sa.Float(), nullable=False), + sa.Column("sigma_init", sa.Float(), nullable=False), + sa.Column("mu", sa.Float(), nullable=False), + sa.Column("sigma", sa.Float(), nullable=False), + sa.Column("is_treatment_arm", sa.Boolean(), nullable=False), + sa.ForeignKeyConstraint(["arm_id"], ["arms_base.arm_id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("arm_id"), + ) + op.create_table( + "contexts", + sa.Column("context_id", sa.Integer(), nullable=False), + sa.Column("experiment_id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("name", sa.String(length=150), nullable=False), + sa.Column("description", sa.String(length=500), nullable=True), + sa.Column("value_type", sa.String(length=50), nullable=False), + sa.ForeignKeyConstraint( + ["experiment_id"], + ["contextual_mabs.experiment_id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.user_id"], + ), + sa.PrimaryKeyConstraint("context_id"), + ) + op.create_table( + "contextual_arms", + sa.Column("arm_id", sa.Integer(), nullable=False), + sa.Column("mu_init", sa.Float(), nullable=False), + sa.Column("sigma_init", sa.Float(), nullable=False), + sa.Column("mu", 
postgresql.ARRAY(sa.Float()), nullable=False), + sa.Column("covariance", postgresql.ARRAY(sa.Float()), nullable=False), + sa.ForeignKeyConstraint(["arm_id"], ["arms_base.arm_id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("arm_id"), + ) + op.create_table( + "draws", + sa.Column("draw_id", sa.String(), nullable=False), + sa.Column("arm_id", sa.Integer(), nullable=False), + sa.Column("experiment_id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("client_id", sa.String(length=36), nullable=False), + sa.Column("draw_datetime_utc", sa.DateTime(timezone=True), nullable=False), + sa.Column("observed_datetime_utc", sa.DateTime(timezone=True), nullable=True), + sa.Column( + "observation_type", + sa.Enum("USER", "AUTO", name="observationtype"), + nullable=True, + ), + sa.Column("reward", sa.Float(), nullable=True), + sa.Column("context_val", postgresql.ARRAY(sa.Float()), nullable=True), + sa.ForeignKeyConstraint( + ["arm_id"], + ["arms.arm_id"], + ), + sa.ForeignKeyConstraint( + ["client_id"], + ["clients.client_id"], + ), + sa.ForeignKeyConstraint( + ["experiment_id"], + ["experiments.experiment_id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.user_id"], + ), + sa.PrimaryKeyConstraint("draw_id"), + ) + op.create_table( + "draws_base", + sa.Column("draw_id", sa.String(), nullable=False), + sa.Column("client_id", sa.String(), nullable=True), + sa.Column("arm_id", sa.Integer(), nullable=False), + sa.Column("experiment_id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("draw_datetime_utc", sa.DateTime(timezone=True), nullable=False), + sa.Column("observed_datetime_utc", sa.DateTime(timezone=True), nullable=True), + sa.Column( + "observation_type", + sa.Enum("USER", "AUTO", name="observationtype"), + nullable=True, + ), + sa.Column("draw_type", sa.String(length=50), nullable=False), + sa.Column("reward", sa.Float(), nullable=True), + sa.ForeignKeyConstraint( + 
["arm_id"], + ["arms_base.arm_id"], + ), + sa.ForeignKeyConstraint( + ["experiment_id"], + ["experiments_base.experiment_id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.user_id"], + ), + sa.PrimaryKeyConstraint("draw_id"), + ) + op.create_table( + "mab_arms", + sa.Column("arm_id", sa.Integer(), nullable=False), + sa.Column("alpha", sa.Float(), nullable=True), + sa.Column("beta", sa.Float(), nullable=True), + sa.Column("mu", sa.Float(), nullable=True), + sa.Column("sigma", sa.Float(), nullable=True), + sa.Column("alpha_init", sa.Float(), nullable=True), + sa.Column("beta_init", sa.Float(), nullable=True), + sa.Column("mu_init", sa.Float(), nullable=True), + sa.Column("sigma_init", sa.Float(), nullable=True), + sa.ForeignKeyConstraint(["arm_id"], ["arms_base.arm_id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("arm_id"), + ) + op.create_table( + "bayes_ab_draws", + sa.Column("draw_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint( + ["draw_id"], ["draws_base.draw_id"], ondelete="CASCADE" + ), + sa.PrimaryKeyConstraint("draw_id"), + ) + op.create_table( + "contextual_draws", + sa.Column("draw_id", sa.String(), nullable=False), + sa.Column("context_val", postgresql.ARRAY(sa.Float()), nullable=False), + sa.ForeignKeyConstraint( + ["draw_id"], ["draws_base.draw_id"], ondelete="CASCADE" + ), + sa.PrimaryKeyConstraint("draw_id"), + ) + op.create_table( + "mab_draws", + sa.Column("draw_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint( + ["draw_id"], ["draws_base.draw_id"], ondelete="CASCADE" + ), + sa.PrimaryKeyConstraint("draw_id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_table("mab_draws") + op.drop_table("contextual_draws") + op.drop_table("bayes_ab_draws") + op.drop_table("mab_arms") + op.drop_table("draws_base") + op.drop_table("draws") + op.drop_table("contextual_arms") + op.drop_table("contexts") + op.drop_table("bayes_ab_arms") + op.drop_table("notifications_db") + op.drop_table("notifications") + op.drop_table("mabs") + op.drop_table("event_messages") + op.drop_table("contextual_mabs") + op.drop_table("context") + op.drop_table("clients") + op.drop_table("bayes_ab_experiments") + op.drop_table("arms_base") + op.drop_table("arms") + op.drop_table("user_workspace") + op.drop_table("pending_invitations") + op.drop_table("experiments_base") + op.drop_table("experiments") + op.drop_table("api_key_rotation_history") + op.drop_table("workspace") + op.drop_table("messages") + op.drop_table("users") + # ### end Alembic commands ### diff --git a/backend/migrations/versions/57173e1aa8ae_update_tables_with_workspace_id.py b/backend/migrations/versions/57173e1aa8ae_update_tables_with_workspace_id.py new file mode 100644 index 0000000..19bc51e --- /dev/null +++ b/backend/migrations/versions/57173e1aa8ae_update_tables_with_workspace_id.py @@ -0,0 +1,131 @@ +"""update tables with workspace id + +Revision ID: 57173e1aa8ae +Revises: 2d3946caceff +Create Date: 2025-05-27 21:10:55.499461 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = "57173e1aa8ae" +down_revision: Union[str, None] = "2d3946caceff" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_table("contextual_arms") + op.drop_table("contexts") + op.drop_table("contextual_mabs") + op.drop_table("contextual_draws") + op.add_column("context", sa.Column("workspace_id", sa.Integer(), nullable=False)) + op.create_foreign_key( + None, "context", "workspace", ["workspace_id"], ["workspace_id"] + ) + op.add_column("draws", sa.Column("workspace_id", sa.Integer(), nullable=False)) + op.create_foreign_key( + None, "draws", "workspace", ["workspace_id"], ["workspace_id"] + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint(None, "draws", type_="foreignkey") + op.drop_column("draws", "workspace_id") + op.drop_constraint(None, "context", type_="foreignkey") + op.drop_column("context", "workspace_id") + op.create_table( + "contextual_draws", + sa.Column("draw_id", sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column( + "context_val", + postgresql.ARRAY(sa.DOUBLE_PRECISION(precision=53)), + autoincrement=False, + nullable=False, + ), + sa.ForeignKeyConstraint( + ["draw_id"], + ["draws_base.draw_id"], + name="contextual_draws_draw_id_fkey", + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("draw_id", name="contextual_draws_pkey"), + ) + op.create_table( + "contextual_mabs", + sa.Column("experiment_id", sa.INTEGER(), autoincrement=False, nullable=False), + sa.ForeignKeyConstraint( + ["experiment_id"], + ["experiments_base.experiment_id"], + name="contextual_mabs_experiment_id_fkey", + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("experiment_id", name="contextual_mabs_pkey"), + postgresql_ignore_search_path=False, + ) + op.create_table( + "contexts", + sa.Column("context_id", sa.INTEGER(), autoincrement=True, nullable=False), + sa.Column("experiment_id", sa.INTEGER(), autoincrement=False, nullable=False), + sa.Column("user_id", sa.INTEGER(), autoincrement=False, nullable=False), + sa.Column("name", sa.VARCHAR(length=150), 
autoincrement=False, nullable=False), + sa.Column( + "description", sa.VARCHAR(length=500), autoincrement=False, nullable=True + ), + sa.Column( + "value_type", sa.VARCHAR(length=50), autoincrement=False, nullable=False + ), + sa.ForeignKeyConstraint( + ["experiment_id"], + ["contextual_mabs.experiment_id"], + name="contexts_experiment_id_fkey", + ), + sa.ForeignKeyConstraint( + ["user_id"], ["users.user_id"], name="contexts_user_id_fkey" + ), + sa.PrimaryKeyConstraint("context_id", name="contexts_pkey"), + ) + op.create_table( + "contextual_arms", + sa.Column("arm_id", sa.INTEGER(), autoincrement=False, nullable=False), + sa.Column( + "mu_init", + sa.DOUBLE_PRECISION(precision=53), + autoincrement=False, + nullable=False, + ), + sa.Column( + "sigma_init", + sa.DOUBLE_PRECISION(precision=53), + autoincrement=False, + nullable=False, + ), + sa.Column( + "mu", + postgresql.ARRAY(sa.DOUBLE_PRECISION(precision=53)), + autoincrement=False, + nullable=False, + ), + sa.Column( + "covariance", + postgresql.ARRAY(sa.DOUBLE_PRECISION(precision=53)), + autoincrement=False, + nullable=False, + ), + sa.ForeignKeyConstraint( + ["arm_id"], + ["arms_base.arm_id"], + name="contextual_arms_arm_id_fkey", + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("arm_id", name="contextual_arms_pkey"), + ) + # ### end Alembic commands ### diff --git a/backend/migrations/versions/5c15463fda65_added_first_name_and_last_name_to_users.py b/backend/migrations/versions/5c15463fda65_added_first_name_and_last_name_to_users.py deleted file mode 100644 index 39b1a4d..0000000 --- a/backend/migrations/versions/5c15463fda65_added_first_name_and_last_name_to_users.py +++ /dev/null @@ -1,36 +0,0 @@ -"""added first name and last name to users - -Revision ID: 5c15463fda65 -Revises: 28adf347e68d -Create Date: 2025-04-26 15:47:23.199751 - -""" - -from typing import Sequence, Union - -import sqlalchemy as sa -from alembic import op - -# revision identifiers, used by Alembic. 
-revision: str = "5c15463fda65" -down_revision: Union[str, None] = "28adf347e68d" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # Add columns as nullable first - op.add_column("users", sa.Column("first_name", sa.String(), nullable=True)) - op.add_column("users", sa.Column("last_name", sa.String(), nullable=True)) - - # Set default values for existing records - op.execute("UPDATE users SET first_name = '', last_name = ''") - - # Make columns non-nullable - op.alter_column("users", "first_name", nullable=False) - op.alter_column("users", "last_name", nullable=False) - - -def downgrade() -> None: - op.drop_column("users", "last_name") - op.drop_column("users", "first_name") diff --git a/backend/migrations/versions/9f7482ba882f_workspace_model.py b/backend/migrations/versions/9f7482ba882f_workspace_model.py deleted file mode 100644 index 3543211..0000000 --- a/backend/migrations/versions/9f7482ba882f_workspace_model.py +++ /dev/null @@ -1,123 +0,0 @@ -"""Workspace model - -Revision ID: 9f7482ba882f -Revises: 275ff74c0866 -Create Date: 2025-05-04 11:56:03.939578 - -""" - -from typing import Sequence, Union - -import sqlalchemy as sa -from alembic import op - -# revision identifiers, used by Alembic. -revision: str = "9f7482ba882f" -down_revision: Union[str, None] = "275ff74c0866" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! 
### - op.create_table( - "workspace", - sa.Column("api_daily_quota", sa.Integer(), nullable=True), - sa.Column("api_key_first_characters", sa.String(length=5), nullable=True), - sa.Column( - "api_key_updated_datetime_utc", sa.DateTime(timezone=True), nullable=True - ), - sa.Column("api_key_rotated_by_user_id", sa.Integer(), nullable=True), - sa.Column("content_quota", sa.Integer(), nullable=True), - sa.Column("created_datetime_utc", sa.DateTime(timezone=True), nullable=False), - sa.Column("hashed_api_key", sa.String(length=96), nullable=True), - sa.Column("updated_datetime_utc", sa.DateTime(timezone=True), nullable=False), - sa.Column("workspace_id", sa.Integer(), nullable=False), - sa.Column("workspace_name", sa.String(), nullable=False), - sa.Column("is_default", sa.Boolean(), nullable=False), - sa.ForeignKeyConstraint( - ["api_key_rotated_by_user_id"], - ["users.user_id"], - ), - sa.PrimaryKeyConstraint("workspace_id"), - sa.UniqueConstraint("hashed_api_key"), - sa.UniqueConstraint("workspace_name"), - ) - op.create_table( - "api_key_rotation_history", - sa.Column("rotation_id", sa.Integer(), nullable=False), - sa.Column("workspace_id", sa.Integer(), nullable=False), - sa.Column("rotated_by_user_id", sa.Integer(), nullable=False), - sa.Column("key_first_characters", sa.String(length=5), nullable=False), - sa.Column("rotation_datetime_utc", sa.DateTime(timezone=True), nullable=False), - sa.ForeignKeyConstraint( - ["rotated_by_user_id"], - ["users.user_id"], - ), - sa.ForeignKeyConstraint( - ["workspace_id"], ["workspace.workspace_id"], ondelete="CASCADE" - ), - sa.PrimaryKeyConstraint("rotation_id"), - ) - op.create_table( - "pending_invitations", - sa.Column("invitation_id", sa.Integer(), nullable=False), - sa.Column("email", sa.String(), nullable=False), - sa.Column("workspace_id", sa.Integer(), nullable=False), - sa.Column( - "role", - sa.Enum("ADMIN", "READ_ONLY", name="userroles", native_enum=False), - nullable=False, - ), - sa.Column("inviter_id", 
sa.Integer(), nullable=False), - sa.Column("created_datetime_utc", sa.DateTime(timezone=True), nullable=False), - sa.ForeignKeyConstraint( - ["inviter_id"], - ["users.user_id"], - ), - sa.ForeignKeyConstraint( - ["workspace_id"], ["workspace.workspace_id"], ondelete="CASCADE" - ), - sa.PrimaryKeyConstraint("invitation_id"), - ) - op.create_table( - "user_workspace", - sa.Column("created_datetime_utc", sa.DateTime(timezone=True), nullable=False), - sa.Column( - "default_workspace", - sa.Boolean(), - server_default=sa.text("false"), - nullable=False, - ), - sa.Column("updated_datetime_utc", sa.DateTime(timezone=True), nullable=False), - sa.Column("user_id", sa.Integer(), nullable=False), - sa.Column( - "user_role", - sa.Enum("ADMIN", "READ_ONLY", name="userroles", native_enum=False), - nullable=False, - ), - sa.Column("workspace_id", sa.Integer(), nullable=False), - sa.ForeignKeyConstraint(["user_id"], ["users.user_id"], ondelete="CASCADE"), - sa.ForeignKeyConstraint( - ["workspace_id"], ["workspace.workspace_id"], ondelete="CASCADE" - ), - sa.PrimaryKeyConstraint("user_id", "workspace_id"), - ) - op.add_column( - "experiments_base", sa.Column("workspace_id", sa.Integer(), nullable=False) - ) - op.create_foreign_key( - None, "experiments_base", "workspace", ["workspace_id"], ["workspace_id"] - ) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! 
### - op.drop_constraint(None, "experiments_base", type_="foreignkey") - op.drop_column("experiments_base", "workspace_id") - op.drop_table("user_workspace") - op.drop_table("pending_invitations") - op.drop_table("api_key_rotation_history") - op.drop_table("workspace") - # ### end Alembic commands ### diff --git a/backend/migrations/versions/ecddd830b464_remove_user_api_key.py b/backend/migrations/versions/ecddd830b464_remove_user_api_key.py deleted file mode 100644 index b03b032..0000000 --- a/backend/migrations/versions/ecddd830b464_remove_user_api_key.py +++ /dev/null @@ -1,70 +0,0 @@ -"""Remove User API key - -Revision ID: ecddd830b464 -Revises: 9f7482ba882f -Create Date: 2025-05-21 13:59:22.199884 - -""" - -from typing import Sequence, Union - -import sqlalchemy as sa -from alembic import op -from sqlalchemy.dialects import postgresql - -# revision identifiers, used by Alembic. -revision: str = "ecddd830b464" -down_revision: Union[str, None] = "9f7482ba882f" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.drop_constraint("users_hashed_api_key_key", "users", type_="unique") - op.drop_column("users", "api_daily_quota") - op.drop_column("users", "hashed_api_key") - op.drop_column("users", "api_key_updated_datetime_utc") - op.drop_column("users", "api_key_first_characters") - op.drop_column("users", "experiments_quota") - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! 
### - op.add_column( - "users", - sa.Column( - "experiments_quota", sa.INTEGER(), autoincrement=False, nullable=True - ), - ) - op.add_column( - "users", - sa.Column( - "api_key_first_characters", - sa.VARCHAR(length=5), - autoincrement=False, - nullable=False, - ), - ) - op.add_column( - "users", - sa.Column( - "api_key_updated_datetime_utc", - postgresql.TIMESTAMP(timezone=True), - autoincrement=False, - nullable=False, - ), - ) - op.add_column( - "users", - sa.Column( - "hashed_api_key", sa.VARCHAR(length=96), autoincrement=False, nullable=False - ), - ) - op.add_column( - "users", - sa.Column("api_daily_quota", sa.INTEGER(), autoincrement=False, nullable=True), - ) - op.create_unique_constraint("users_hashed_api_key_key", "users", ["hashed_api_key"]) - # ### end Alembic commands ### diff --git a/backend/migrations/versions/faf4228e13a3_clean_start.py b/backend/migrations/versions/faf4228e13a3_clean_start.py deleted file mode 100644 index 71af813..0000000 --- a/backend/migrations/versions/faf4228e13a3_clean_start.py +++ /dev/null @@ -1,257 +0,0 @@ -"""clean start - -Revision ID: faf4228e13a3 -Revises: -Create Date: 2025-04-17 21:18:03.761219 - -""" - -from typing import Sequence, Union - -import sqlalchemy as sa -from alembic import op -from sqlalchemy.dialects import postgresql - -# revision identifiers, used by Alembic. -revision: str = "faf4228e13a3" -down_revision: Union[str, None] = None -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! 
### - op.create_table( - "users", - sa.Column("user_id", sa.Integer(), nullable=False), - sa.Column("username", sa.String(), nullable=False), - sa.Column("hashed_password", sa.String(length=96), nullable=False), - sa.Column("hashed_api_key", sa.String(length=96), nullable=False), - sa.Column("api_key_first_characters", sa.String(length=5), nullable=False), - sa.Column( - "api_key_updated_datetime_utc", sa.DateTime(timezone=True), nullable=False - ), - sa.Column("experiments_quota", sa.Integer(), nullable=True), - sa.Column("api_daily_quota", sa.Integer(), nullable=True), - sa.Column("created_datetime_utc", sa.DateTime(timezone=True), nullable=False), - sa.Column("updated_datetime_utc", sa.DateTime(timezone=True), nullable=False), - sa.Column("access_level", sa.String(), nullable=False), - sa.Column("is_active", sa.Boolean(), nullable=False), - sa.Column("is_verified", sa.Boolean(), nullable=False), - sa.PrimaryKeyConstraint("user_id"), - sa.UniqueConstraint("hashed_api_key"), - sa.UniqueConstraint("username"), - ) - op.create_table( - "experiments_base", - sa.Column("experiment_id", sa.Integer(), nullable=False), - sa.Column("name", sa.String(length=150), nullable=False), - sa.Column("description", sa.String(length=500), nullable=False), - sa.Column("user_id", sa.Integer(), nullable=False), - sa.Column("is_active", sa.Boolean(), nullable=False), - sa.Column("exp_type", sa.String(length=50), nullable=False), - sa.Column("prior_type", sa.String(length=50), nullable=False), - sa.Column("reward_type", sa.String(length=50), nullable=False), - sa.Column("created_datetime_utc", sa.DateTime(timezone=True), nullable=False), - sa.Column("n_trials", sa.Integer(), nullable=False), - sa.Column("last_trial_datetime_utc", sa.DateTime(timezone=True), nullable=True), - sa.ForeignKeyConstraint( - ["user_id"], - ["users.user_id"], - ), - sa.PrimaryKeyConstraint("experiment_id"), - ) - op.create_table( - "messages", - sa.Column("message_id", sa.Integer(), nullable=False), - 
sa.Column("user_id", sa.Integer(), nullable=False), - sa.Column("text", sa.String(), nullable=False), - sa.Column("title", sa.String(), nullable=False), - sa.Column("is_unread", sa.Boolean(), nullable=False), - sa.Column("created_datetime_utc", sa.DateTime(timezone=True), nullable=False), - sa.Column("message_type", sa.String(length=50), nullable=False), - sa.ForeignKeyConstraint( - ["user_id"], - ["users.user_id"], - ), - sa.PrimaryKeyConstraint("message_id"), - ) - op.create_table( - "arms_base", - sa.Column("arm_id", sa.Integer(), nullable=False), - sa.Column("experiment_id", sa.Integer(), nullable=False), - sa.Column("user_id", sa.Integer(), nullable=False), - sa.Column("name", sa.String(length=150), nullable=False), - sa.Column("description", sa.String(length=500), nullable=False), - sa.Column("arm_type", sa.String(length=50), nullable=False), - sa.Column("n_outcomes", sa.Integer(), nullable=False), - sa.ForeignKeyConstraint( - ["experiment_id"], - ["experiments_base.experiment_id"], - ), - sa.ForeignKeyConstraint( - ["user_id"], - ["users.user_id"], - ), - sa.PrimaryKeyConstraint("arm_id"), - ) - op.create_table( - "contextual_mabs", - sa.Column("experiment_id", sa.Integer(), nullable=False), - sa.ForeignKeyConstraint( - ["experiment_id"], ["experiments_base.experiment_id"], ondelete="CASCADE" - ), - sa.PrimaryKeyConstraint("experiment_id"), - ) - op.create_table( - "event_messages", - sa.Column("message_id", sa.Integer(), nullable=False), - sa.Column("experiment_id", sa.Integer(), nullable=False), - sa.ForeignKeyConstraint( - ["experiment_id"], - ["experiments_base.experiment_id"], - ), - sa.ForeignKeyConstraint( - ["message_id"], ["messages.message_id"], ondelete="CASCADE" - ), - sa.PrimaryKeyConstraint("message_id"), - ) - op.create_table( - "mabs", - sa.Column("experiment_id", sa.Integer(), nullable=False), - sa.ForeignKeyConstraint( - ["experiment_id"], ["experiments_base.experiment_id"], ondelete="CASCADE" - ), - 
sa.PrimaryKeyConstraint("experiment_id"), - ) - op.create_table( - "notifications", - sa.Column("notification_id", sa.Integer(), nullable=False), - sa.Column("experiment_id", sa.Integer(), nullable=False), - sa.Column("user_id", sa.Integer(), nullable=False), - sa.Column( - "notification_type", - sa.Enum( - "DAYS_ELAPSED", - "TRIALS_COMPLETED", - "PERCENTAGE_BETTER", - name="eventtype", - ), - nullable=False, - ), - sa.Column("notification_value", sa.Integer(), nullable=False), - sa.Column("is_active", sa.Boolean(), nullable=False), - sa.ForeignKeyConstraint( - ["experiment_id"], - ["experiments_base.experiment_id"], - ), - sa.ForeignKeyConstraint( - ["user_id"], - ["users.user_id"], - ), - sa.PrimaryKeyConstraint("notification_id"), - ) - op.create_table( - "contexts", - sa.Column("context_id", sa.Integer(), nullable=False), - sa.Column("experiment_id", sa.Integer(), nullable=False), - sa.Column("user_id", sa.Integer(), nullable=False), - sa.Column("name", sa.String(length=150), nullable=False), - sa.Column("description", sa.String(length=500), nullable=True), - sa.Column("value_type", sa.String(length=50), nullable=False), - sa.ForeignKeyConstraint( - ["experiment_id"], - ["contextual_mabs.experiment_id"], - ), - sa.ForeignKeyConstraint( - ["user_id"], - ["users.user_id"], - ), - sa.PrimaryKeyConstraint("context_id"), - ) - op.create_table( - "contextual_arms", - sa.Column("arm_id", sa.Integer(), nullable=False), - sa.Column("mu_init", sa.Float(), nullable=False), - sa.Column("sigma_init", sa.Float(), nullable=False), - sa.Column("mu", postgresql.ARRAY(sa.Float()), nullable=False), - sa.Column("covariance", postgresql.ARRAY(sa.Float()), nullable=False), - sa.ForeignKeyConstraint(["arm_id"], ["arms_base.arm_id"], ondelete="CASCADE"), - sa.PrimaryKeyConstraint("arm_id"), - ) - op.create_table( - "draws_base", - sa.Column("draw_id", sa.String(), nullable=False), - sa.Column("arm_id", sa.Integer(), nullable=False), - sa.Column("experiment_id", sa.Integer(), 
nullable=False), - sa.Column("user_id", sa.Integer(), nullable=False), - sa.Column("draw_datetime_utc", sa.DateTime(timezone=True), nullable=False), - sa.Column("observed_datetime_utc", sa.DateTime(timezone=True), nullable=True), - sa.Column( - "observation_type", - sa.Enum("USER", "AUTO", name="observationtype"), - nullable=True, - ), - sa.Column("draw_type", sa.String(length=50), nullable=False), - sa.Column("reward", sa.Float(), nullable=True), - sa.ForeignKeyConstraint( - ["arm_id"], - ["arms_base.arm_id"], - ), - sa.ForeignKeyConstraint( - ["experiment_id"], - ["experiments_base.experiment_id"], - ), - sa.ForeignKeyConstraint( - ["user_id"], - ["users.user_id"], - ), - sa.PrimaryKeyConstraint("draw_id"), - ) - op.create_table( - "mab_arms", - sa.Column("arm_id", sa.Integer(), nullable=False), - sa.Column("alpha", sa.Float(), nullable=True), - sa.Column("beta", sa.Float(), nullable=True), - sa.Column("mu", sa.Float(), nullable=True), - sa.Column("sigma", sa.Float(), nullable=True), - sa.ForeignKeyConstraint(["arm_id"], ["arms_base.arm_id"], ondelete="CASCADE"), - sa.PrimaryKeyConstraint("arm_id"), - ) - op.create_table( - "contextual_draws", - sa.Column("draw_id", sa.String(), nullable=False), - sa.Column("context_val", postgresql.ARRAY(sa.Float()), nullable=False), - sa.ForeignKeyConstraint( - ["draw_id"], ["draws_base.draw_id"], ondelete="CASCADE" - ), - sa.PrimaryKeyConstraint("draw_id"), - ) - op.create_table( - "mab_draws", - sa.Column("draw_id", sa.String(), nullable=False), - sa.ForeignKeyConstraint( - ["draw_id"], ["draws_base.draw_id"], ondelete="CASCADE" - ), - sa.PrimaryKeyConstraint("draw_id"), - ) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! 
### - op.drop_table("mab_draws") - op.drop_table("contextual_draws") - op.drop_table("mab_arms") - op.drop_table("draws_base") - op.drop_table("contextual_arms") - op.drop_table("contexts") - op.drop_table("notifications") - op.drop_table("mabs") - op.drop_table("event_messages") - op.drop_table("contextual_mabs") - op.drop_table("arms_base") - op.drop_table("messages") - op.drop_table("experiments_base") - op.drop_table("users") - # ### end Alembic commands ### diff --git a/backend/migrations/versions/feb042798cad_added_sticky_assignments_and_autofail.py b/backend/migrations/versions/feb042798cad_added_sticky_assignments_and_autofail.py deleted file mode 100644 index 824c2ba..0000000 --- a/backend/migrations/versions/feb042798cad_added_sticky_assignments_and_autofail.py +++ /dev/null @@ -1,59 +0,0 @@ -"""added sticky assignments and autofail - -Revision ID: feb042798cad -Revises: faf4228e13a3 -Create Date: 2025-04-18 15:11:40.688651 - -""" - -from typing import Sequence, Union - -import sqlalchemy as sa -from alembic import op - -# revision identifiers, used by Alembic. -revision: str = "feb042798cad" -down_revision: Union[str, None] = "faf4228e13a3" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! 
### - autofailunittype = sa.Enum( - "DAYS", - "HOURS", - name="autofailunittype", - ) - autofailunittype.create(op.get_bind()) - - op.add_column( - "experiments_base", sa.Column("sticky_assignment", sa.Boolean(), nullable=False) - ) - op.add_column( - "experiments_base", sa.Column("auto_fail", sa.Boolean(), nullable=False) - ) - op.add_column( - "experiments_base", sa.Column("auto_fail_value", sa.Integer(), nullable=True) - ) - op.add_column( - "experiments_base", - sa.Column( - "auto_fail_unit", - autofailunittype, - nullable=True, - ), - ) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.drop_column("experiments_base", "auto_fail_unit") - op.drop_column("experiments_base", "auto_fail_value") - op.drop_column("experiments_base", "auto_fail") - op.drop_column("experiments_base", "sticky_assignment") - - sa.Enum(name="autofailunittype").drop(op.get_bind()) - - # ### end Alembic commands ### From 94373ce41c2418138fdd6fc3440b5568f46ac0b9 Mon Sep 17 00:00:00 2001 From: poornimaramesh Date: Fri, 30 May 2025 12:32:08 +0300 Subject: [PATCH 42/74] add choose arm and update arm routers + functions --- backend/app/experiments/dependencies.py | 197 +++++++++++++- backend/app/experiments/models.py | 89 +++++-- backend/app/experiments/routers.py | 245 ++++++++++++++---- backend/app/experiments/sampling_utils.py | 6 +- backend/app/experiments/schemas.py | 51 ++-- ...8f_update_models_for_treatment_arm_and_.py | 64 +++++ 6 files changed, 554 insertions(+), 98 deletions(-) create mode 100644 backend/migrations/versions/4c06937ee88f_update_models_for_treatment_arm_and_.py diff --git a/backend/app/experiments/dependencies.py b/backend/app/experiments/dependencies.py index 796e015..7ee1aa9 100644 --- a/backend/app/experiments/dependencies.py +++ b/backend/app/experiments/dependencies.py @@ -1,7 +1,32 @@ +from datetime import datetime, timezone +from typing import Union + +import numpy as np +from 
fastapi.exceptions import HTTPException from sqlalchemy.ext.asyncio import AsyncSession -from .models import ExperimentDB, get_notifications_from_db -from .schemas import ExperimentSample, NotificationsResponse +from .models import ( + ArmDB, + DrawDB, + ExperimentDB, + get_draw_by_id, + get_draws_with_rewards_by_experiment_id, + get_experiment_by_id_from_db, + get_notifications_from_db, + save_observation_to_db, +) +from .sampling_utils import update_arm +from .schemas import ( + ArmPriors, + ArmResponse, + DrawResponse, + ExperimentSample, + ExperimentsEnum, + NotificationsResponse, + ObservationType, + Outcome, + RewardLikelihood, +) async def experiments_db_to_schema( @@ -35,3 +60,171 @@ async def experiments_db_to_schema( ) return all_experiments + + +async def validate_experiment_and_draw( + experiment_id: int, draw_id: str, workspace_id: int, asession: AsyncSession +) -> tuple[ExperimentDB, DrawDB]: + """ + Validate the experiment and draw. + """ + experiment = await get_experiment_by_id_from_db( + workspace_id=workspace_id, experiment_id=experiment_id, asession=asession + ) + # Check experiment + if experiment is None: + raise HTTPException( + status_code=404, detail=f"Experiment with id {experiment_id} not found" + ) + + draw = await get_draw_by_id(draw_id=draw_id, asession=asession) + # Check draw + if draw is None: + raise HTTPException(status_code=404, detail=f"Draw with id {draw_id} not found") + if draw.experiment_id != experiment_id: + raise HTTPException( + status_code=404, + detail=( + f"Draw with id {draw_id} does not belong to " + f"experiment with id {experiment_id}" + ), + ) + if draw.reward: + raise HTTPException( + status_code=400, + detail=f"Draw with id {draw_id} has already been updated with a reward.", + ) + + return ExperimentSample.model_validate(experiment), DrawResponse.model_validate( + draw + ) + + +async def format_rewards_for_arm_update( + experiment: ExperimentDB, chosen_arm_id: int, asession: AsyncSession +) -> 
tuple[list[float], Union[list[float], None], Union[list[float], None]]: + """ + Format the rewards for the arm update. + """ + previous_rewards = await get_draws_with_rewards_by_experiment_id( + experiment_id=experiment.experiment_id, asession=asession + ) + if not previous_rewards: + return [], [], [] + + treatments, contexts = [], [] + if experiment.exp_type != ExperimentsEnum.BAYESAB.value: + rewards = [ + draw.reward for draw in previous_rewards if draw.arm_id == chosen_arm_id + ] + + else: + rewards = [draw.reward for draw in previous_rewards] + treatments = [ + float(experiment.arms[draw.arm_id].is_treatment_arm) + for draw in previous_rewards + ] + + if experiment.exp_type == ExperimentsEnum.CMAB.value: + contexts = [draw.context_val for draw in previous_rewards] + return rewards, contexts, treatments + + +async def update_arm_based_on_outcome( + experiment: ExperimentDB, + draw: DrawDB, + rewards: list[float], + observation_type: ObservationType, + contexts: Union[list[float], list[None]], + treatments: Union[list[float], None], + asession: AsyncSession, +) -> ArmResponse: + """ + Update the arm parameters based on the outcome. + + This is a helper function to allow `auto_fail` job to call + it as well. 
+ """ + update_experiment_metadata(experiment) + + arm = get_arm_from_experiment(experiment, draw.arm_id) + arm.n_outcomes += 1 + + experiment_data = ExperimentSample.model_validate(experiment) + arm = await update_arm_parameters( + arm=arm, + experiment_data=experiment_data, + chosen_arm=np.argwhere([arm.arm_id for arm in experiment.arms] == draw.arm_id)[ + 0 + ][0], + rewards=rewards, + contexts=contexts, + treatments=treatments, + ) + await save_updated_data( + arm=arm, + draw=draw, + reward=rewards[0], + observation_type=observation_type, + asession=asession, + ) + + return ArmResponse.model_validate(arm) + + +def update_experiment_metadata(experiment: ExperimentDB) -> None: + """Update experiment metadata with new trial information""" + experiment.n_trials += 1 + experiment.last_trial_datetime_utc = datetime.now(tz=timezone.utc) + + +def get_arm_from_experiment(experiment: ExperimentDB, arm_id: int) -> ArmDB: + """Get and validate the arm from the experiment""" + arms = [a for a in experiment.arms if a.arm_id == arm_id] + if not arms: + raise HTTPException(status_code=404, detail=f"Arm with id {arm_id} not found") + return arms[0] + + +async def update_arm_parameters( + arm: ArmDB, + experiment_data: ExperimentSample, + chosen_arm: int, + rewards: list[float], + contexts: Union[list[float], list[None]], + treatments: Union[list[float], None], +) -> ArmDB: + """Update the arm parameters based on the reward type and outcome""" + if experiment_data.reward_type == RewardLikelihood.BERNOULLI: + Outcome(rewards[0]) # Check if reward is 0 or 1 + params = update_arm( + experiment=experiment_data, + rewards=rewards, + arm_to_update=chosen_arm, + context=contexts, + treatments=treatments, + ) + if experiment_data.prior_type == ArmPriors.BETA: + arm.alpha, arm.beta = params + elif experiment_data.prior_type == ArmPriors.NORMAL: + arm.mu, arm.covariance = params + else: + raise HTTPException( + status_code=400, + detail="Prior type not supported.", + ) + return arm + + 
+async def save_updated_data( + arm: ArmDB, + draw: DrawDB, + reward: float, + observation_type: ObservationType, + asession: AsyncSession, +) -> None: + """Save the updated arm and observation data""" + await asession.commit() + await save_observation_to_db( + draw=draw, reward=reward, observation_type=observation_type, asession=asession + ) diff --git a/backend/app/experiments/models.py b/backend/app/experiments/models.py index 2b8725c..fa7a29d 100644 --- a/backend/app/experiments/models.py +++ b/backend/app/experiments/models.py @@ -156,9 +156,6 @@ class ArmDB(Base): # IDs arm_id: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False) - user_id: Mapped[int] = mapped_column( - Integer, ForeignKey("users.user_id"), nullable=False - ) workspace_id: Mapped[int] = mapped_column( Integer, ForeignKey("workspace.workspace_id"), nullable=False ) @@ -178,6 +175,7 @@ class ArmDB(Base): covariance: Mapped[Optional[list[float]]] = mapped_column( ARRAY(Float), nullable=True ) + is_treatment_arm: Mapped[bool] = mapped_column(Boolean, nullable=True, default=True) alpha_init: Mapped[Optional[float]] = mapped_column(Float, nullable=True) beta_init: Mapped[Optional[float]] = mapped_column(Float, nullable=True) @@ -234,9 +232,6 @@ class DrawDB(Base): experiment_id: Mapped[int] = mapped_column( Integer, ForeignKey("experiments.experiment_id"), nullable=False ) - user_id: Mapped[int] = mapped_column( - Integer, ForeignKey("users.user_id"), nullable=False - ) workspace_id: Mapped[int] = mapped_column( Integer, ForeignKey("workspace.workspace_id"), nullable=False ) @@ -280,7 +275,6 @@ def to_dict(self) -> dict: "draw_id": self.draw_id, "arm_id": self.arm_id, "experiment_id": self.experiment_id, - "user_id": self.user_id, "client_id": self.client_id, "draw_datetime_utc": self.draw_datetime_utc, "observed_datetime_utc": self.observed_datetime_utc, @@ -303,9 +297,6 @@ class ContextDB(Base): experiment_id: Mapped[int] = mapped_column( Integer, 
ForeignKey("experiments.experiment_id"), nullable=False ) - user_id: Mapped[int] = mapped_column( - Integer, ForeignKey("users.user_id"), nullable=False - ) workspace_id: Mapped[int] = mapped_column( Integer, ForeignKey("workspace.workspace_id"), nullable=False ) @@ -344,9 +335,6 @@ class ClientDB(Base): client_id: Mapped[str] = mapped_column( String, primary_key=True, default=lambda x: str(uuid.uuid4()) ) - user_id: Mapped[int] = mapped_column( - Integer, ForeignKey("users.user_id"), nullable=False - ) experiment_id: Mapped[int] = mapped_column( Integer, ForeignKey("experiments.experiment_id"), nullable=False ) @@ -368,6 +356,17 @@ class ClientDB(Base): + "ExperimentDB.sticky_assignment == True)", ) + def to_dict(self) -> dict: + """ + Convert the ORM object to a dictionary. + """ + return { + "client_id": self.client_id, + "experiment_id": self.experiment_id, + "workspace_id": self.workspace_id, + "draws": [draw.to_dict() for draw in self.draws], + } + # --- Notifications model --- class NotificationsDB(Base): @@ -497,7 +496,6 @@ async def save_experiment_to_db( arms = [ ArmDB( - user_id=user_id, workspace_id=workspace_id, # description name=arm.name, @@ -516,13 +514,13 @@ async def save_experiment_to_db( beta_init=arm.beta_init, alpha=arm.alpha_init, beta=arm.beta_init, + is_treatment_arm=arm.is_treatment_arm, ) for arm in experiment.arms ] if experiment.contexts and len_contexts > 0: contexts = [ ContextDB( - user_id=user_id, workspace_id=workspace_id, name=context.name, description=context.description, @@ -665,29 +663,18 @@ async def save_draw_to_db( draw_id: str, arm_id: int, experiment_id: int, - user_id: int | None, workspace_id: int, - client_id: str, + client_id: str | None, context: list[float] | None, asession: AsyncSession, ) -> DrawDB: """ Save a draw to the database. 
""" - if not user_id: - experiment = await get_experiment_by_id_from_db( - experiment_id=experiment_id, workspace_id=workspace_id, asession=asession - ) - if not experiment: - raise ValueError( - f"Experiment with id {experiment_id} not found for the given ID." - ) - experiment_id = experiment.experiment_id draw = DrawDB( draw_id=draw_id, arm_id=arm_id, experiment_id=experiment_id, - user_id=user_id, workspace_id=workspace_id, client_id=client_id, draw_datetime_utc=datetime.now(timezone.utc), @@ -698,3 +685,51 @@ async def save_draw_to_db( await asession.refresh(draw) return draw + + +async def save_observation_to_db( + draw: DrawDB, + reward: float, + observation_type: ObservationType, + asession: AsyncSession, +) -> DrawDB: + """ + Save an observation to the database. + """ + draw.observed_datetime_utc = datetime.now(timezone.utc) + draw.observation_type = observation_type + draw.reward = reward + + await asession.commit() + await asession.refresh(draw) + + return draw + + +async def get_draws_by_experiment_id( + experiment_id: int, asession: AsyncSession +) -> Sequence[DrawDB]: + """ + Get all draws for a given experiment ID. + """ + statement = ( + select(DrawDB) + .where(DrawDB.experiment_id == experiment_id) + .order_by(DrawDB.draw_datetime_utc.desc()) + ) + return (await asession.execute(statement)).unique().scalars().all() + + +async def get_draws_with_rewards_by_experiment_id( + experiment_id: int, asession: AsyncSession +) -> Sequence[DrawDB]: + """ + Get all draws with rewards for a given experiment ID. 
+ """ + statement = ( + select(DrawDB) + .where(DrawDB.experiment_id == experiment_id) + .where(DrawDB.reward.is_not(None)) + .order_by(DrawDB.draw_datetime_utc.desc()) + ) + return (await asession.execute(statement)).unique().scalars().all() diff --git a/backend/app/experiments/routers.py b/backend/app/experiments/routers.py index c4279e2..92f53db 100644 --- a/backend/app/experiments/routers.py +++ b/backend/app/experiments/routers.py @@ -1,10 +1,13 @@ -from typing import Annotated +from typing import Annotated, Optional +from uuid import uuid4 +import numpy as np from fastapi import APIRouter, Depends from fastapi.exceptions import HTTPException from sqlalchemy.ext.asyncio import AsyncSession from ..auth.dependencies import ( + authenticate_workspace_key, get_verified_user, require_admin_role, ) @@ -12,21 +15,37 @@ from ..users.models import UserDB from ..utils import setup_logger from ..workspaces.models import ( + WorkspaceDB, get_user_default_workspace, ) -from .dependencies import experiments_db_to_schema +from .dependencies import ( + experiments_db_to_schema, + format_rewards_for_arm_update, + save_updated_data, + update_arm_parameters, + validate_experiment_and_draw, +) from .models import ( delete_experiment_by_id_from_db, get_all_experiment_types_from_db, get_all_experiments_from_db, + get_draw_by_id, + get_draws_by_experiment_id, get_experiment_by_id_from_db, + save_draw_to_db, save_experiment_to_db, save_notifications_to_db, ) +from .sampling_utils import choose_arm from .schemas import ( + ArmResponse, + ContextInput, + ContextType, + DrawResponse, Experiment, ExperimentSample, ExperimentsEnum, + Outcome, ) router = APIRouter(prefix="/experiment", tags=["Experiments"]) @@ -264,46 +283,182 @@ async def delete_experiment_by_id( # --- Draw and update arms --- -# @router.get("/{experiment_id}/draw", response_model=DrawResponse) -# async def draw_arm( -# experiment_id: int, -# context: Optional[list] = None, -# draw_id: Optional[str] = None, -# 
workspace_db: WorkspaceDB = Depends(authenticate_workspace_key), -# asession: AsyncSession = Depends(get_async_session), -# ) -> DrawResponse: -# """ -# Draw an arm from the specified experiment. -# """ -# workspace_id = workspace_db.workspace_id - -# experiment = await get_experiment_by_id_from_db( -# workspace_id=workspace_id, experiment_id=experiment_id, asession=asession -# ) -# if experiment is None: -# raise HTTPException( -# status_code=404, detail=f"Experiment with id {experiment_id} not found" -# ) - -# if (experiment.exp_type == ExperimentsEnum.CMAB.value) and (not context): -# raise HTTPException( -# status_code=400, detail="Context is required for CMAB experiments." -# ) - -# # Check for existing draws -# if draw_id is None: -# draw_id = str(uuid4()) - -# existing_draw = await get_draw_by_id(draw_id=draw_id, asession=asession) -# if existing_draw: -# raise HTTPException( -# status_code=400, detail=f"Draw with id {draw_id} already exists." -# ) - -# # Perform the draw -# experiment_data = ExperimentSample.model_validate(experiment) -# chosen_arm = choose_arm(experiment=experiment_data, context=context) -# chosen_arm_id = experiment.arms[chosen_arm].arm_id - -# try: -# draw = await sa +@router.get("/{experiment_id}/draw", response_model=DrawResponse) +async def draw_arm( + experiment_id: int, + contexts: Optional[list[ContextInput]] = None, + draw_id: Optional[str] = None, + workspace_db: WorkspaceDB = Depends(authenticate_workspace_key), + asession: AsyncSession = Depends(get_async_session), +) -> DrawResponse: + """ + Draw an arm from the specified experiment. 
+ """ + workspace_id = workspace_db.workspace_id + + experiment = await get_experiment_by_id_from_db( + workspace_id=workspace_id, experiment_id=experiment_id, asession=asession + ) + if experiment is None: + raise HTTPException( + status_code=404, detail=f"Experiment with id {experiment_id} not found" + ) + + # Check contexts + if (experiment.exp_type == ExperimentsEnum.CMAB.value) and (not contexts): + raise HTTPException( + status_code=400, detail="Context is required for CMAB experiments." + ) + elif (experiment.exp_type == ExperimentsEnum.CMAB.value) and contexts: + if len(contexts) != len(experiment.contexts): + raise HTTPException( + status_code=400, + detail=( + f"Expected {len(experiment.contexts)} contexts" + f" but got {len(contexts)}." + ), + ) + + # Check for existing draws + if draw_id is None: + draw_id = str(uuid4()) + + existing_draw = await get_draw_by_id(draw_id=draw_id, asession=asession) + if existing_draw: + raise HTTPException( + status_code=400, detail=f"Draw with id {draw_id} already exists." 
+ ) + + # -- Perform the draw --- + experiment_data = ExperimentSample.model_validate(experiment) + + # Validate contexts input + if contexts: + sorted_contexts = list(sorted(contexts, key=lambda x: x.context_id)) + try: + for c_input, c_exp in zip( + sorted_contexts, + sorted(experiment_data.contexts, key=lambda x: x.context_id), + ): + if c_exp.value_type == ContextType.BINARY.value: + Outcome(c_input.context_value) + except ValueError as e: + raise HTTPException( + status_code=400, + detail=f"Invalid context value: {e}", + ) from e + + # Choose arm + chosen_arm = choose_arm( + experiment=experiment_data, + context=[c.context_value for c in sorted_contexts] if contexts else None, + ) + chosen_arm_id = experiment.arms[chosen_arm].arm_id + + try: + draw = await save_draw_to_db( + draw_id=draw_id, + arm_id=chosen_arm_id, + experiment_id=experiment_id, + workspace_id=workspace_id, + client_id=None, # TODO: Update for sticky assignment + context=[c.context_value for c in sorted_contexts], + asession=asession, + ) + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error saving draw: {str(e)}", + ) from e + + return DrawResponse.model_validate( + draw_id=draw_id, + draw_datetime_utc=draw.draw_datetime_utc, + arm=experiment_data.arms[chosen_arm], + context_val=draw.context_val, + ) + + +@router.put("/{experiment_id}/{draw_id}/{reward}", response_model=ArmResponse) +async def update_arm( + experiment_id: int, + draw_id: str, + reward: float, + workspace_db: WorkspaceDB = Depends(authenticate_workspace_key), + asession: AsyncSession = Depends(get_async_session), +) -> ArmResponse: + """ + Update the arm with the given reward. 
+ """ + + experiment, draw = await validate_experiment_and_draw( + experiment_id=experiment_id, + draw_id=draw_id, + workspace_id=workspace_db.workspace_id, + asession=asession, + ) + + # Get rewards + chosen_arm_index = int( + np.argwhere(np.array([arm.arm_id for arm in experiment.arms]) == draw.arm_id)[ + 0 + ][0], + ) + rewards, contexts, treatments = await format_rewards_for_arm_update( + experiment=experiment, chosen_arm_id=draw.arm_id, asession=asession + ) + rewards = ([reward] + rewards) if rewards else [reward] + contexts = ([draw.context_val] + contexts) if contexts else [draw.context_val] + new_treatment = [float(experiment.arms[chosen_arm_index].is_treatment_arm)] + treatments = (new_treatment + treatments) if treatments else new_treatment + # Update the arm with the given reward + try: + arm = await update_arm_parameters( + arm=experiment.arms[chosen_arm_index], + experiment_data=ExperimentSample.model_validate(experiment), + chosen_arm=chosen_arm_index, + rewards=rewards, + contexts=contexts, + treatments=treatments, + ) + await save_updated_data( + arm=arm, + draw=draw, + reward=reward, + observation_type=experiment.observation_type, + asession=asession, + ) + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error updating arm: {str(e)}", + ) from e + + return ArmResponse.model_validate(arm) + + +@router.get("/{experiment_id}/rewards", response_model=list[DrawResponse]) +async def get_rewards( + experiment_id: int, + workspace_db: WorkspaceDB = Depends(authenticate_workspace_key), + asession: AsyncSession = Depends(get_async_session), +) -> list[DrawResponse]: + """ + Retrieve all rewards for the specified experiment. 
+ """ + experiment = await get_experiment_by_id_from_db( + workspace_id=workspace_db.workspace_id, + experiment_id=experiment_id, + asession=asession, + ) + + if not experiment: + raise HTTPException( + status_code=404, detail=f"Experiment with id {experiment_id} not found" + ) + + draws = await get_draws_by_experiment_id( + experiment_id=experiment_id, asession=asession + ) + + return [DrawResponse.model_validate(draw) for draw in draws] diff --git a/backend/app/experiments/sampling_utils.py b/backend/app/experiments/sampling_utils.py index 278d0c9..40bd25a 100644 --- a/backend/app/experiments/sampling_utils.py +++ b/backend/app/experiments/sampling_utils.py @@ -215,7 +215,7 @@ def update_arm( rewards: list[float], arm_to_update: Optional[int] = None, context: Optional[Union[list, np.ndarray, None]] = None, - treatments: Optional[list[int]] = None, + treatments: Optional[list[float]] = None, ) -> Any: """ Update the arm parameters based on the experiment type and reward. @@ -293,7 +293,7 @@ def update_arm( current_covariance=np.array(arm.covariance), reward=rewards[0], llhood_sigma=1.0, # TODO: Assuming a fixed likelihood sigma - context=context, + context=np.array(context), ) # TODO: only supports Bernoulli likelihood else: @@ -301,7 +301,7 @@ def update_arm( current_mu=np.array(arm.mu), current_covariance=np.array(arm.covariance), reward=np.array(rewards), - context=context, + context=np.array(context), link_function=ContextLinkFunctions.LOGISTIC, reward_likelihood=experiment.reward_type, prior_type=experiment.prior_type, diff --git a/backend/app/experiments/schemas.py b/backend/app/experiments/schemas.py index 26a6c41..660ae98 100644 --- a/backend/app/experiments/schemas.py +++ b/backend/app/experiments/schemas.py @@ -212,6 +212,10 @@ class Arm(BaseModel): examples=[None, 1.0], description="Standard deviation parameter for Normal prior", ) + is_treatment_arm: Optional[bool] = Field( + default=True, + description="Whether the arm is a treatment arm or not", + ) 
@model_validator(mode="after") def check_values(self) -> Self: @@ -242,7 +246,6 @@ class ArmResponse(Arm): beta: Optional[Union[float, None]] mu: Optional[List[Union[float, None]]] covariance: Optional[List[List[Union[float, None]]]] - draws: Optional[List[Union[float, None]]] model_config = ConfigDict( from_attributes=True, ) @@ -301,30 +304,13 @@ class Client(BaseModel): ) -# Draws -class Draw(BaseModel): +class DrawResponse(BaseModel): """ - Pydantic model for a draw. + Pydantic model for a response for draw creation """ model_config = ConfigDict(from_attributes=True) - # Draw info - reward: Optional[float] = Field( - description="Reward observed from the draw", - default=None, - ) - context_val: Optional[list[float]] = Field( - description="Context values associated with the draw", - default=None, - ) - - -class DrawResponse(Draw): - """ - Pydantic model for a response for draw creation - """ - draw_id: str = Field( description="Unique identifier for the draw", examples=["draw_123"], @@ -337,8 +323,18 @@ class DrawResponse(Draw): description="Timestamp of when the reward was observed", default=None, ) + + # Draw info + reward: Optional[float] = Field( + description="Reward observed from the draw", + default=None, + ) + context_val: Optional[list[float]] = Field( + description="Context values associated with the draw", + default=None, + ) arm: ArmResponse - client: Client + client: Optional[Client] = None # Experiments @@ -472,6 +468,19 @@ def check_arm_missing_params(self) -> Self: raise ValueError(f"{val} prior needs {','.join(missing_params)}.") return self + @model_validator(mode="after") + def check_treatment_info(self) -> Self: + """ + Validate that the treatment arm information is set correctly. 
+ """ + arms = self.arms + if self.exp_type == ExperimentsEnum.BAYESAB: + if not any(arm.is_treatment_arm for arm in arms): + raise ValueError("At least one arm must be a treatment arm.") + if all(arm.is_treatment_arm for arm in arms): + raise ValueError("At least one arm must be a control arm.") + return self + @model_validator(mode="after") def check_prior_reward_type_combo(self) -> Self: """ diff --git a/backend/migrations/versions/4c06937ee88f_update_models_for_treatment_arm_and_.py b/backend/migrations/versions/4c06937ee88f_update_models_for_treatment_arm_and_.py new file mode 100644 index 0000000..4a6f7d6 --- /dev/null +++ b/backend/migrations/versions/4c06937ee88f_update_models_for_treatment_arm_and_.py @@ -0,0 +1,64 @@ +"""update models for treatment arm and debugging + +Revision ID: 4c06937ee88f +Revises: 57173e1aa8ae +Create Date: 2025-05-30 12:14:04.889301 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "4c06937ee88f" +down_revision: Union[str, None] = "57173e1aa8ae" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("arms", sa.Column("is_treatment_arm", sa.Boolean(), nullable=True)) + op.drop_constraint("arms_user_id_fkey", "arms", type_="foreignkey") + op.drop_column("arms", "user_id") + op.drop_constraint("clients_user_id_fkey", "clients", type_="foreignkey") + op.drop_column("clients", "user_id") + op.drop_constraint("context_user_id_fkey", "context", type_="foreignkey") + op.drop_column("context", "user_id") + op.drop_constraint("draws_user_id_fkey", "draws", type_="foreignkey") + op.drop_column("draws", "user_id") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.add_column( + "draws", sa.Column("user_id", sa.INTEGER(), autoincrement=False, nullable=False) + ) + op.create_foreign_key( + "draws_user_id_fkey", "draws", "users", ["user_id"], ["user_id"] + ) + op.add_column( + "context", + sa.Column("user_id", sa.INTEGER(), autoincrement=False, nullable=False), + ) + op.create_foreign_key( + "context_user_id_fkey", "context", "users", ["user_id"], ["user_id"] + ) + op.add_column( + "clients", + sa.Column("user_id", sa.INTEGER(), autoincrement=False, nullable=False), + ) + op.create_foreign_key( + "clients_user_id_fkey", "clients", "users", ["user_id"], ["user_id"] + ) + op.add_column( + "arms", sa.Column("user_id", sa.INTEGER(), autoincrement=False, nullable=False) + ) + op.create_foreign_key( + "arms_user_id_fkey", "arms", "users", ["user_id"], ["user_id"] + ) + op.drop_column("arms", "is_treatment_arm") + # ### end Alembic commands ### From 3b273c44eb5136717846ea895b0c6acf24001326 Mon Sep 17 00:00:00 2001 From: poornimaramesh Date: Fri, 30 May 2025 13:20:21 +0300 Subject: [PATCH 43/74] fix linting --- backend/app/experiments/dependencies.py | 30 +++++++---- backend/app/experiments/models.py | 4 +- backend/app/experiments/routers.py | 61 +++++++++++++++-------- backend/app/experiments/sampling_utils.py | 8 +-- backend/app/experiments/schemas.py | 1 + 5 files changed, 67 insertions(+), 37 deletions(-) diff --git a/backend/app/experiments/dependencies.py b/backend/app/experiments/dependencies.py index 7ee1aa9..7e302c6 100644 --- a/backend/app/experiments/dependencies.py +++ b/backend/app/experiments/dependencies.py @@ -19,7 +19,6 @@ from .schemas import ( ArmPriors, ArmResponse, - DrawResponse, ExperimentSample, ExperimentsEnum, NotificationsResponse, @@ -95,14 +94,12 @@ async def validate_experiment_and_draw( detail=f"Draw with id {draw_id} has already been updated with a reward.", ) - return ExperimentSample.model_validate(experiment), DrawResponse.model_validate( - draw - ) + return experiment, draw async def 
format_rewards_for_arm_update( experiment: ExperimentDB, chosen_arm_id: int, asession: AsyncSession -) -> tuple[list[float], Union[list[float], None], Union[list[float], None]]: +) -> tuple[list[float], list[list[float]] | None, list[float] | None]: """ Format the rewards for the arm update. """ @@ -110,14 +107,16 @@ async def format_rewards_for_arm_update( experiment_id=experiment.experiment_id, asession=asession ) if not previous_rewards: - return [], [], [] + return [], None, None + + rewards = [] + treatments = None + contexts = None - treatments, contexts = [], [] if experiment.exp_type != ExperimentsEnum.BAYESAB.value: rewards = [ draw.reward for draw in previous_rewards if draw.arm_id == chosen_arm_id ] - else: rewards = [draw.reward for draw in previous_rewards] treatments = [ @@ -126,7 +125,16 @@ async def format_rewards_for_arm_update( ] if experiment.exp_type == ExperimentsEnum.CMAB.value: - contexts = [draw.context_val for draw in previous_rewards] + contexts = [] + for draw in previous_rewards: + if draw.context_val: + contexts.append(draw.context_val) + else: + raise ValueError( + f"Context value is missing for draw id {draw.draw_id}" + f" in CMAB experiment {draw.experiment_id}." 
+ ) + return rewards, contexts, treatments @@ -135,7 +143,7 @@ async def update_arm_based_on_outcome( draw: DrawDB, rewards: list[float], observation_type: ObservationType, - contexts: Union[list[float], list[None]], + contexts: Union[list[list[float]], None], treatments: Union[list[float], None], asession: AsyncSession, ) -> ArmResponse: @@ -191,7 +199,7 @@ async def update_arm_parameters( experiment_data: ExperimentSample, chosen_arm: int, rewards: list[float], - contexts: Union[list[float], list[None]], + contexts: Union[list[list[float]], None], treatments: Union[list[float], None], ) -> ArmDB: """Update the arm parameters based on the reward type and outcome""" diff --git a/backend/app/experiments/models.py b/backend/app/experiments/models.py index fa7a29d..5342672 100644 --- a/backend/app/experiments/models.py +++ b/backend/app/experiments/models.py @@ -111,9 +111,9 @@ def has_contexts(self) -> bool: return self.exp_type == "cmab" @property - def context_list(self) -> list["ContextDB"] | list[None]: + def context_list(self) -> list["ContextDB"] | list: """Get contexts, returning empty list if not applicable.""" - return self.contexts if self.has_contexts else [] + return self.contexts if self.has_contexts and self.contexts is not None else [] def to_dict(self) -> dict: """ diff --git a/backend/app/experiments/routers.py b/backend/app/experiments/routers.py index 92f53db..855fb94 100644 --- a/backend/app/experiments/routers.py +++ b/backend/app/experiments/routers.py @@ -284,7 +284,7 @@ async def delete_experiment_by_id( # --- Draw and update arms --- @router.get("/{experiment_id}/draw", response_model=DrawResponse) -async def draw_arm( +async def draw_experiment_arm( experiment_id: int, contexts: Optional[list[ContextInput]] = None, draw_id: Optional[str] = None, @@ -310,12 +310,12 @@ async def draw_arm( status_code=400, detail="Context is required for CMAB experiments." 
) elif (experiment.exp_type == ExperimentsEnum.CMAB.value) and contexts: - if len(contexts) != len(experiment.contexts): + context_length = 0 if not experiment.contexts else len(experiment.contexts) + if len(contexts) != context_length: raise HTTPException( status_code=400, detail=( - f"Expected {len(experiment.contexts)} contexts" - f" but got {len(contexts)}." + f"Expected {context_length} contexts" f" but got {len(contexts)}." ), ) @@ -336,9 +336,13 @@ async def draw_arm( if contexts: sorted_contexts = list(sorted(contexts, key=lambda x: x.context_id)) try: + exp_contexts = experiment_data.contexts or [] + sorted_exp_contexts = ( + sorted(exp_contexts, key=lambda x: x.context_id) if exp_contexts else [] + ) for c_input, c_exp in zip( sorted_contexts, - sorted(experiment_data.contexts, key=lambda x: x.context_id), + sorted_exp_contexts, ): if c_exp.value_type == ContextType.BINARY.value: Outcome(c_input.context_value) @@ -362,7 +366,7 @@ async def draw_arm( experiment_id=experiment_id, workspace_id=workspace_id, client_id=None, # TODO: Update for sticky assignment - context=[c.context_value for c in sorted_contexts], + context=[c.context_value for c in sorted_contexts] if contexts else None, asession=asession, ) except Exception as e: @@ -371,16 +375,17 @@ async def draw_arm( detail=f"Error saving draw: {str(e)}", ) from e - return DrawResponse.model_validate( - draw_id=draw_id, - draw_datetime_utc=draw.draw_datetime_utc, - arm=experiment_data.arms[chosen_arm], - context_val=draw.context_val, - ) + draw_response_data = { + "draw_id": draw_id, + "draw_datetime_utc": draw.draw_datetime_utc, + "arm": experiment_data.arms[chosen_arm], + "context_val": draw.context_val, + } + return DrawResponse.model_validate(draw_response_data) @router.put("/{experiment_id}/{draw_id}/{reward}", response_model=ArmResponse) -async def update_arm( +async def update_experiment_arm( experiment_id: int, draw_id: str, reward: float, @@ -407,25 +412,39 @@ async def update_arm( rewards, 
contexts, treatments = await format_rewards_for_arm_update( experiment=experiment, chosen_arm_id=draw.arm_id, asession=asession ) - rewards = ([reward] + rewards) if rewards else [reward] - contexts = ([draw.context_val] + contexts) if contexts else [draw.context_val] + + rewards_list = [reward] if rewards is None else [reward] + rewards + + context_list = None if not draw.context_val else [draw.context_val] + if contexts and context_list: + context_list = context_list + contexts + new_treatment = [float(experiment.arms[chosen_arm_index].is_treatment_arm)] - treatments = (new_treatment + treatments) if treatments else new_treatment + treatments_list = ( + new_treatment if treatments is None else new_treatment + treatments + ) + # Update the arm with the given reward try: + # Get experiment type for observation type + experiment_data = ExperimentSample.model_validate(experiment) + arm = await update_arm_parameters( arm=experiment.arms[chosen_arm_index], - experiment_data=ExperimentSample.model_validate(experiment), + experiment_data=experiment_data, chosen_arm=chosen_arm_index, - rewards=rewards, - contexts=contexts, - treatments=treatments, + rewards=rewards_list, + contexts=context_list, + treatments=treatments_list, ) + + observation_type = experiment_data.observation_type + await save_updated_data( arm=arm, draw=draw, reward=reward, - observation_type=experiment.observation_type, + observation_type=observation_type, asession=asession, ) except Exception as e: diff --git a/backend/app/experiments/sampling_utils.py b/backend/app/experiments/sampling_utils.py index 40bd25a..ea53128 100644 --- a/backend/app/experiments/sampling_utils.py +++ b/backend/app/experiments/sampling_utils.py @@ -194,13 +194,15 @@ def choose_arm( elif experiment.prior_type == ArmPriors.NORMAL: mus = [np.array(arm.mu) for arm in experiment.arms] covariances = [np.array(arm.covariance) for arm in experiment.arms] - if not context: - context = np.ones_like(mus[0]) + + context_array = ( + 
np.ones_like(mus[0]) if context is None else np.array(context) + ) return _sample_normal( mus=mus, covariances=covariances, - context=context, + context=context_array, link_function=( ContextLinkFunctions.NONE if experiment.reward_type == RewardLikelihood.NORMAL diff --git a/backend/app/experiments/schemas.py b/backend/app/experiments/schemas.py index 660ae98..13c9fde 100644 --- a/backend/app/experiments/schemas.py +++ b/backend/app/experiments/schemas.py @@ -539,6 +539,7 @@ class ExperimentSample(ExperimentBase): experiment_id: int n_trials: int last_trial_datetime_utc: Optional[str] = None + observation_type: ObservationType = ObservationType.USER arms: list[ArmResponse] contexts: Optional[list[ContextResponse]] = None From 156cbe30a7383bc671ca62ddc4d8c3ec44b48623 Mon Sep 17 00:00:00 2001 From: poornimaramesh Date: Fri, 30 May 2025 16:05:35 +0300 Subject: [PATCH 44/74] debug routers for beta-binary mab --- backend/app/experiments/models.py | 6 ++-- backend/app/experiments/routers.py | 16 +++++++-- ...57dd34c4cd4_debugging_nullable_variable.py | 34 +++++++++++++++++++ 3 files changed, 52 insertions(+), 4 deletions(-) create mode 100644 backend/migrations/versions/157dd34c4cd4_debugging_nullable_variable.py diff --git a/backend/app/experiments/models.py b/backend/app/experiments/models.py index 5342672..98b881b 100644 --- a/backend/app/experiments/models.py +++ b/backend/app/experiments/models.py @@ -236,7 +236,7 @@ class DrawDB(Base): Integer, ForeignKey("workspace.workspace_id"), nullable=False ) client_id: Mapped[str] = mapped_column( - String(length=36), ForeignKey("clients.client_id"), nullable=False + String(length=36), ForeignKey("clients.client_id"), nullable=True ) # Logging @@ -264,7 +264,7 @@ class DrawDB(Base): "ClientDB", back_populates="draws", lazy="joined", - primaryjoin="and_(DrawDB.client_id==ClientDB.client_id, ExperimentDB.sticky_assignment == True)", # noqa: E501 + primaryjoin="DrawDB.client_id==ClientDB.client_id", # noqa: E501 ) def 
to_dict(self) -> dict: @@ -700,6 +700,8 @@ async def save_observation_to_db( draw.observation_type = observation_type draw.reward = reward + print(draw.to_dict()) + await asession.commit() await asession.refresh(draw) diff --git a/backend/app/experiments/routers.py b/backend/app/experiments/routers.py index 855fb94..6874ffb 100644 --- a/backend/app/experiments/routers.py +++ b/backend/app/experiments/routers.py @@ -377,7 +377,7 @@ async def draw_experiment_arm( draw_response_data = { "draw_id": draw_id, - "draw_datetime_utc": draw.draw_datetime_utc, + "draw_datetime_utc": str(draw.draw_datetime_utc), "arm": experiment_data.arms[chosen_arm], "context_val": draw.context_val, } @@ -480,4 +480,16 @@ async def get_rewards( experiment_id=experiment_id, asession=asession ) - return [DrawResponse.model_validate(draw) for draw in draws] + return [ + DrawResponse.model_validate( + { + "draw_id": draw.draw_id, + "draw_datetime_utc": str(draw.draw_datetime_utc), + "observed_datetime_utc": str(draw.observed_datetime_utc), + "arm": [arm for arm in experiment.arms if arm.arm_id == draw.arm_id][0], + "reward": draw.reward, + "context_val": draw.context_val, + } + ) + for draw in draws + ] diff --git a/backend/migrations/versions/157dd34c4cd4_debugging_nullable_variable.py b/backend/migrations/versions/157dd34c4cd4_debugging_nullable_variable.py new file mode 100644 index 0000000..a46adf5 --- /dev/null +++ b/backend/migrations/versions/157dd34c4cd4_debugging_nullable_variable.py @@ -0,0 +1,34 @@ +"""debugging nullable variable + +Revision ID: 157dd34c4cd4 +Revises: 4c06937ee88f +Create Date: 2025-05-30 15:43:35.254416 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. 
+revision: str = "157dd34c4cd4" +down_revision: Union[str, None] = "4c06937ee88f" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column( + "draws", "client_id", existing_type=sa.VARCHAR(length=36), nullable=True + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column( + "draws", "client_id", existing_type=sa.VARCHAR(length=36), nullable=False + ) + # ### end Alembic commands ### From ec2b7bd5e109aa46ceabc2046ea9f183536ee17e Mon Sep 17 00:00:00 2001 From: poornimaramesh Date: Fri, 30 May 2025 17:13:37 +0300 Subject: [PATCH 45/74] debug normal/real-valued experiments --- backend/app/experiments/dependencies.py | 22 ++++++---------------- backend/app/experiments/models.py | 17 ++++++++--------- backend/app/experiments/routers.py | 17 +++++++---------- backend/app/experiments/sampling_utils.py | 5 ++++- 4 files changed, 25 insertions(+), 36 deletions(-) diff --git a/backend/app/experiments/dependencies.py b/backend/app/experiments/dependencies.py index 7e302c6..b83216f 100644 --- a/backend/app/experiments/dependencies.py +++ b/backend/app/experiments/dependencies.py @@ -142,10 +142,8 @@ async def update_arm_based_on_outcome( experiment: ExperimentDB, draw: DrawDB, rewards: list[float], - observation_type: ObservationType, contexts: Union[list[list[float]], None], treatments: Union[list[float], None], - asession: AsyncSession, ) -> ArmResponse: """ Update the arm parameters based on the outcome. 
@@ -158,24 +156,17 @@ async def update_arm_based_on_outcome( arm = get_arm_from_experiment(experiment, draw.arm_id) arm.n_outcomes += 1 - experiment_data = ExperimentSample.model_validate(experiment) - arm = await update_arm_parameters( + experiment_data = ExperimentSample.model_validate(experiment.to_dict()) + chosen_arm = np.argwhere([a.arm_id == arm.arm_id for a in experiment.arms])[0][0] + + await update_arm_parameters( arm=arm, experiment_data=experiment_data, - chosen_arm=np.argwhere([arm.arm_id for arm in experiment.arms] == draw.arm_id)[ - 0 - ][0], + chosen_arm=chosen_arm, rewards=rewards, contexts=contexts, treatments=treatments, ) - await save_updated_data( - arm=arm, - draw=draw, - reward=rewards[0], - observation_type=observation_type, - asession=asession, - ) return ArmResponse.model_validate(arm) @@ -201,7 +192,7 @@ async def update_arm_parameters( rewards: list[float], contexts: Union[list[list[float]], None], treatments: Union[list[float], None], -) -> ArmDB: +) -> None: """Update the arm parameters based on the reward type and outcome""" if experiment_data.reward_type == RewardLikelihood.BERNOULLI: Outcome(rewards[0]) # Check if reward is 0 or 1 @@ -221,7 +212,6 @@ async def update_arm_parameters( status_code=400, detail="Prior type not supported.", ) - return arm async def save_updated_data( diff --git a/backend/app/experiments/models.py b/backend/app/experiments/models.py index 98b881b..48cb938 100644 --- a/backend/app/experiments/models.py +++ b/backend/app/experiments/models.py @@ -132,10 +132,10 @@ def to_dict(self) -> dict: "exp_type": self.exp_type, "prior_type": self.prior_type, "reward_type": self.reward_type, - "created_datetime_utc": self.created_datetime_utc, + "created_datetime_utc": str(self.created_datetime_utc), "is_active": self.is_active, "n_trials": self.n_trials, - "last_trial_datetime_utc": self.last_trial_datetime_utc, + "last_trial_datetime_utc": str(self.last_trial_datetime_utc), "arms": [arm.to_dict() for arm in 
self.arms], "draws": [draw.to_dict() for draw in self.draws], "contexts": ( @@ -627,17 +627,18 @@ async def delete_experiment_by_id_from_db( .where(ClientDB.experiment_id == experiment_id) ) - await asession.execute( - delete(ArmDB) - .where(ArmDB.workspace_id == workspace_id) - .where(ArmDB.experiment_id == experiment_id) - ) await asession.execute( delete(DrawDB) .where(DrawDB.workspace_id == workspace_id) .where(DrawDB.experiment_id == experiment_id) ) + await asession.execute( + delete(ArmDB) + .where(ArmDB.workspace_id == workspace_id) + .where(ArmDB.experiment_id == experiment_id) + ) + await asession.execute( delete(ExperimentDB) .where(ExperimentDB.workspace_id == workspace_id) @@ -700,8 +701,6 @@ async def save_observation_to_db( draw.observation_type = observation_type draw.reward = reward - print(draw.to_dict()) - await asession.commit() await asession.refresh(draw) diff --git a/backend/app/experiments/routers.py b/backend/app/experiments/routers.py index 6874ffb..8c0791b 100644 --- a/backend/app/experiments/routers.py +++ b/backend/app/experiments/routers.py @@ -22,7 +22,7 @@ experiments_db_to_schema, format_rewards_for_arm_update, save_updated_data, - update_arm_parameters, + update_arm_based_on_outcome, validate_experiment_and_draw, ) from .models import ( @@ -427,34 +427,31 @@ async def update_experiment_arm( # Update the arm with the given reward try: # Get experiment type for observation type - experiment_data = ExperimentSample.model_validate(experiment) - arm = await update_arm_parameters( - arm=experiment.arms[chosen_arm_index], - experiment_data=experiment_data, - chosen_arm=chosen_arm_index, + await update_arm_based_on_outcome( + experiment=experiment, + draw=draw, rewards=rewards_list, contexts=context_list, treatments=treatments_list, ) - observation_type = experiment_data.observation_type + observation_type = draw.observation_type await save_updated_data( - arm=arm, + arm=experiment.arms[chosen_arm_index], draw=draw, reward=reward, 
observation_type=observation_type, asession=asession, ) + return ArmResponse.model_validate(experiment.arms[chosen_arm_index]) except Exception as e: raise HTTPException( status_code=500, detail=f"Error updating arm: {str(e)}", ) from e - return ArmResponse.model_validate(arm) - @router.get("/{experiment_id}/rewards", response_model=list[DrawResponse]) async def get_rewards( diff --git a/backend/app/experiments/sampling_utils.py b/backend/app/experiments/sampling_utils.py index ea53128..09f89fd 100644 --- a/backend/app/experiments/sampling_utils.py +++ b/backend/app/experiments/sampling_utils.py @@ -276,16 +276,19 @@ def update_arm( ), f"arm_to_update must be provided for {experiment.exp_type} experiments." arm = experiment.arms[arm_to_update] - assert arm.alpha and arm.beta, "Arm must have alpha and beta parameters." # Beta-binomial priors if experiment.prior_type == ArmPriors.BETA: + assert arm.alpha and arm.beta, "Arm must have alpha and beta parameters." return _update_arm_beta_binomial( alpha=arm.alpha, beta=arm.beta, reward=Outcome(rewards[0]) ) # Normal priors elif experiment.prior_type == ArmPriors.NORMAL: + assert ( + arm.mu and arm.covariance + ), "Arm must have mu and covariance parameters." 
if context is None: context = np.ones_like(arm.mu) # Normal likelihood From b9c583d78df891ea2271715559478b69426affde Mon Sep 17 00:00:00 2001 From: poornimaramesh Date: Tue, 3 Jun 2025 15:19:31 +0300 Subject: [PATCH 46/74] debug Bayes AB beta-binom --- backend/app/experiments/dependencies.py | 55 ++++++++++++++++------- backend/app/experiments/routers.py | 2 +- backend/app/experiments/sampling_utils.py | 27 ++++++++--- backend/app/experiments/schemas.py | 4 +- 4 files changed, 63 insertions(+), 25 deletions(-) diff --git a/backend/app/experiments/dependencies.py b/backend/app/experiments/dependencies.py index b83216f..df3826a 100644 --- a/backend/app/experiments/dependencies.py +++ b/backend/app/experiments/dependencies.py @@ -118,11 +118,16 @@ async def format_rewards_for_arm_update( draw.reward for draw in previous_rewards if draw.arm_id == chosen_arm_id ] else: - rewards = [draw.reward for draw in previous_rewards] - treatments = [ - float(experiment.arms[draw.arm_id].is_treatment_arm) - for draw in previous_rewards - ] + treatments = [] + for draw in previous_rewards: + rewards.append(draw.reward) + treatments.append( + [ + float(arm.is_treatment_arm) + for arm in experiment.arms + if arm.arm_id == draw.arm_id + ][0] + ) if experiment.exp_type == ExperimentsEnum.CMAB.value: contexts = [] @@ -156,12 +161,11 @@ async def update_arm_based_on_outcome( arm = get_arm_from_experiment(experiment, draw.arm_id) arm.n_outcomes += 1 - experiment_data = ExperimentSample.model_validate(experiment.to_dict()) chosen_arm = np.argwhere([a.arm_id == arm.arm_id for a in experiment.arms])[0][0] await update_arm_parameters( arm=arm, - experiment_data=experiment_data, + experiment=experiment, chosen_arm=chosen_arm, rewards=rewards, contexts=contexts, @@ -187,13 +191,14 @@ def get_arm_from_experiment(experiment: ExperimentDB, arm_id: int) -> ArmDB: async def update_arm_parameters( arm: ArmDB, - experiment_data: ExperimentSample, + experiment: ExperimentDB, chosen_arm: int, rewards: 
list[float], contexts: Union[list[list[float]], None], treatments: Union[list[float], None], ) -> None: """Update the arm parameters based on the reward type and outcome""" + experiment_data = ExperimentSample.model_validate(experiment.to_dict()) if experiment_data.reward_type == RewardLikelihood.BERNOULLI: Outcome(rewards[0]) # Check if reward is 0 or 1 params = update_arm( @@ -203,15 +208,33 @@ async def update_arm_parameters( context=contexts, treatments=treatments, ) - if experiment_data.prior_type == ArmPriors.BETA: - arm.alpha, arm.beta = params - elif experiment_data.prior_type == ArmPriors.NORMAL: - arm.mu, arm.covariance = params + + if experiment_data.exp_type == ExperimentsEnum.BAYESAB: + if experiment_data.prior_type == ArmPriors.NORMAL: + mus, covariances = params + for arm in experiment.arms: + if arm.is_treatment_arm: + arm.mu = [mus[0]] + arm.covariance = covariances[0] + else: + arm.mu = [mus[1]] + arm.covariance = covariances[1] + else: + raise HTTPException( + status_code=400, + detail="Prior type not supported for Bayesian A/B experiments.", + ) else: - raise HTTPException( - status_code=400, - detail="Prior type not supported.", - ) + if experiment_data.prior_type == ArmPriors.BETA: + arm.alpha, arm.beta = params + elif experiment_data.prior_type == ArmPriors.NORMAL: + print("Len params:", len(params)) + arm.mu, arm.covariance = params + else: + raise HTTPException( + status_code=400, + detail="Prior type not supported.", + ) async def save_updated_data( diff --git a/backend/app/experiments/routers.py b/backend/app/experiments/routers.py index 8c0791b..d2c4649 100644 --- a/backend/app/experiments/routers.py +++ b/backend/app/experiments/routers.py @@ -330,7 +330,7 @@ async def draw_experiment_arm( ) # -- Perform the draw --- - experiment_data = ExperimentSample.model_validate(experiment) + experiment_data = ExperimentSample.model_validate(experiment.to_dict()) # Validate contexts input if contexts: diff --git 
a/backend/app/experiments/sampling_utils.py b/backend/app/experiments/sampling_utils.py index 09f89fd..73a6a21 100644 --- a/backend/app/experiments/sampling_utils.py +++ b/backend/app/experiments/sampling_utils.py @@ -241,13 +241,25 @@ def update_arm( if experiment.exp_type == ExperimentsEnum.BAYESAB: assert treatments, "Treatments must be provided for Bayesian A/B tests." - - mus = np.array([arm.mu for arm in experiment.arms] + [0.0]) + assert [ + arm.mu for arm in experiment.arms + ], "Arms must have mu parameters for Bayesian A/B tests." + assert [ + arm.covariance for arm in experiment.arms + ], "Arms must have covariance parameters for Bayesian A/B tests." + + mus = np.array([arm.mu[0] for arm in experiment.arms if arm.mu] + [0.0]) covariances = np.diag( - [np.array(arm.covariance).ravel()[0] for arm in experiment.arms] + [1.0] + [ + np.array(arm.covariance).ravel()[0] + for arm in experiment.arms + if arm.covariance + ] + + [1.0] + ) + context = ( + np.zeros((len(experiment.arms), 3)) if not context else np.array(context) ) - - context = np.zeros((len(rewards), 3)) context[:, 0] = np.array(treatments) context[:, 1] = 1.0 - np.array(treatments) context[:, 2] = 1.0 @@ -268,7 +280,10 @@ def update_arm( treatment_mu, control_mu, _ = new_mus treatment_sigma, control_sigma, _ = np.diag(new_covariances) - return [treatment_mu, control_mu], [[treatment_sigma]], [[control_sigma]] + return [treatment_mu, control_mu], [ + [[float(treatment_sigma)]], + [[float(control_sigma)]], + ] else: # Update for MABs and CMABs assert ( diff --git a/backend/app/experiments/schemas.py b/backend/app/experiments/schemas.py index 13c9fde..ffd7e6e 100644 --- a/backend/app/experiments/schemas.py +++ b/backend/app/experiments/schemas.py @@ -407,8 +407,8 @@ class Experiment(ExperimentBase): # Relationships arms: list[Arm] notifications: Notifications - contexts: Optional[list[Context]] = [] - clients: Optional[list[Client]] = [] + contexts: Optional[list[Context]] + clients: 
Optional[list[Client]] @model_validator(mode="after") def auto_fail_unit_and_value_set(self) -> Self: From 7050907107f96f563ab7a07bfa79b88a66913e14 Mon Sep 17 00:00:00 2001 From: poornimaramesh Date: Tue, 3 Jun 2025 15:40:32 +0300 Subject: [PATCH 47/74] debug mabs --- backend/app/experiments/dependencies.py | 1 - backend/app/experiments/sampling_utils.py | 11 +++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/backend/app/experiments/dependencies.py b/backend/app/experiments/dependencies.py index df3826a..a6cafa9 100644 --- a/backend/app/experiments/dependencies.py +++ b/backend/app/experiments/dependencies.py @@ -228,7 +228,6 @@ async def update_arm_parameters( if experiment_data.prior_type == ArmPriors.BETA: arm.alpha, arm.beta = params elif experiment_data.prior_type == ArmPriors.NORMAL: - print("Len params:", len(params)) arm.mu, arm.covariance = params else: raise HTTPException( diff --git a/backend/app/experiments/sampling_utils.py b/backend/app/experiments/sampling_utils.py index 73a6a21..cf89620 100644 --- a/backend/app/experiments/sampling_utils.py +++ b/backend/app/experiments/sampling_utils.py @@ -287,10 +287,13 @@ def update_arm( else: # Update for MABs and CMABs assert ( - arm_to_update - ), f"arm_to_update must be provided for {experiment.exp_type} experiments." - - arm = experiment.arms[arm_to_update] + isinstance(arm_to_update, int) and arm_to_update >= 0 + ), "Arm to update must be a non-negative integer." 
+ arm = ( + experiment.arms[arm_to_update] + if arm_to_update is not None + else experiment.arms[0] + ) # Beta-binomial priors if experiment.prior_type == ArmPriors.BETA: From 0f86e7cdfc4c084e83e28d18ced183977c841665 Mon Sep 17 00:00:00 2001 From: poornimaramesh Date: Tue, 3 Jun 2025 17:25:42 +0300 Subject: [PATCH 48/74] debug cmabs --- backend/app/experiments/dependencies.py | 4 +++- backend/app/experiments/routers.py | 8 +++++++- backend/app/experiments/sampling_utils.py | 20 +++++++------------- 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/backend/app/experiments/dependencies.py b/backend/app/experiments/dependencies.py index a6cafa9..77435bc 100644 --- a/backend/app/experiments/dependencies.py +++ b/backend/app/experiments/dependencies.py @@ -161,7 +161,9 @@ async def update_arm_based_on_outcome( arm = get_arm_from_experiment(experiment, draw.arm_id) arm.n_outcomes += 1 - chosen_arm = np.argwhere([a.arm_id == arm.arm_id for a in experiment.arms])[0][0] + chosen_arm = int( + np.argwhere([a.arm_id == arm.arm_id for a in experiment.arms])[0][0] + ) await update_arm_parameters( arm=arm, diff --git a/backend/app/experiments/routers.py b/backend/app/experiments/routers.py index d2c4649..7a03dc7 100644 --- a/backend/app/experiments/routers.py +++ b/backend/app/experiments/routers.py @@ -283,7 +283,7 @@ async def delete_experiment_by_id( # --- Draw and update arms --- -@router.get("/{experiment_id}/draw", response_model=DrawResponse) +@router.put("/{experiment_id}/draw", response_model=DrawResponse) async def draw_experiment_arm( experiment_id: int, contexts: Optional[list[ContextInput]] = None, @@ -340,6 +340,12 @@ async def draw_experiment_arm( sorted_exp_contexts = ( sorted(exp_contexts, key=lambda x: x.context_id) if exp_contexts else [] ) + if [c1.context_id for c1 in sorted_contexts] != [ + c2.context_id for c2 in sorted_exp_contexts + ]: + raise ValueError( + "Provided contexts do not match the experiment's expected contexts." 
+ ) for c_input, c_exp in zip( sorted_contexts, sorted_exp_contexts, diff --git a/backend/app/experiments/sampling_utils.py b/backend/app/experiments/sampling_utils.py index cf89620..31e518d 100644 --- a/backend/app/experiments/sampling_utils.py +++ b/backend/app/experiments/sampling_utils.py @@ -98,7 +98,7 @@ def _update_arm_normal( """ # Likelihood covariance matrix inverse llhood_covariance_inv = np.eye(len(current_mu)) / llhood_sigma**2 - if context: + if context is not None: llhood_covariance_inv *= context.T @ context # Prior covariance matrix inverse @@ -109,10 +109,10 @@ def _update_arm_normal( # New mean llhood_term: Union[np.ndarray, float] = reward / llhood_sigma**2 - if context: - llhood_term = context.T * llhood_term + print("llhood_term", llhood_term) + if context is not None: + llhood_term = (context * llhood_term).squeeze() new_mu = new_covariance @ ((prior_covariance_inv @ current_mu) + llhood_term) - return new_mu.tolist(), new_covariance.tolist() @@ -286,14 +286,8 @@ def update_arm( ] else: # Update for MABs and CMABs - assert ( - isinstance(arm_to_update, int) and arm_to_update >= 0 - ), "Arm to update must be a non-negative integer." - arm = ( - experiment.arms[arm_to_update] - if arm_to_update is not None - else experiment.arms[0] - ) + assert arm_to_update is not None, "Arm to update must be provided." 
+ arm = experiment.arms[arm_to_update] # Beta-binomial priors if experiment.prior_type == ArmPriors.BETA: @@ -316,7 +310,7 @@ def update_arm( current_covariance=np.array(arm.covariance), reward=rewards[0], llhood_sigma=1.0, # TODO: Assuming a fixed likelihood sigma - context=np.array(context), + context=np.array(context[0]), ) # TODO: only supports Bernoulli likelihood else: From fb2904b4372b1246430f4ea44a3063176db9ebbf Mon Sep 17 00:00:00 2001 From: poornimaramesh Date: Tue, 3 Jun 2025 18:04:52 +0300 Subject: [PATCH 49/74] update autofail --- backend/app/__init__.py | 5 +- backend/app/experiments/dependencies.py | 18 ++- backend/app/experiments/routers.py | 20 +-- backend/jobs/auto_fail.py | 188 +++--------------------- 4 files changed, 46 insertions(+), 185 deletions(-) diff --git a/backend/app/__init__.py b/backend/app/__init__.py index 67a6f7c..88e9dfd 100644 --- a/backend/app/__init__.py +++ b/backend/app/__init__.py @@ -5,7 +5,7 @@ from fastapi.middleware.cors import CORSMiddleware from redis import asyncio as aioredis -from . import auth, bayes_ab, mab, messages +from . 
import auth, messages from .config import BACKEND_ROOT_PATH, DOMAIN, REDIS_HOST from .experiments.routers import router as experiments_router from .users.routers import ( @@ -58,9 +58,6 @@ def create_app() -> FastAPI: ) app.include_router(experiments_router) - app.include_router(mab.router) - # app.include_router(contextual_mab.router) - app.include_router(bayes_ab.router) app.include_router(auth.router) app.include_router(users_router) app.include_router(messages.router) diff --git a/backend/app/experiments/dependencies.py b/backend/app/experiments/dependencies.py index 77435bc..f9c7380 100644 --- a/backend/app/experiments/dependencies.py +++ b/backend/app/experiments/dependencies.py @@ -98,7 +98,7 @@ async def validate_experiment_and_draw( async def format_rewards_for_arm_update( - experiment: ExperimentDB, chosen_arm_id: int, asession: AsyncSession + experiment: ExperimentDB, chosen_arm_id: int, reward: float, asession: AsyncSession ) -> tuple[list[float], list[list[float]] | None, list[float] | None]: """ Format the rewards for the arm update. @@ -140,7 +140,21 @@ async def format_rewards_for_arm_update( f" in CMAB experiment {draw.experiment_id}." 
) - return rewards, contexts, treatments + rewards_list = [reward] if rewards is None else [reward] + rewards + + context_list = None if not draw.context_val else [draw.context_val] + if contexts and context_list: + context_list = context_list + contexts + + chosen_arm_index = int( + np.argwhere([a.arm_id == chosen_arm_id for a in experiment.arms])[0][0] + ) + new_treatment = [float(experiment.arms[chosen_arm_index].is_treatment_arm)] + treatments_list = ( + new_treatment if treatments is None else new_treatment + treatments + ) + + return rewards_list, context_list, treatments_list async def update_arm_based_on_outcome( diff --git a/backend/app/experiments/routers.py b/backend/app/experiments/routers.py index 7a03dc7..27f885d 100644 --- a/backend/app/experiments/routers.py +++ b/backend/app/experiments/routers.py @@ -415,25 +415,15 @@ async def update_experiment_arm( 0 ][0], ) - rewards, contexts, treatments = await format_rewards_for_arm_update( - experiment=experiment, chosen_arm_id=draw.arm_id, asession=asession - ) - - rewards_list = [reward] if rewards is None else [reward] + rewards - - context_list = None if not draw.context_val else [draw.context_val] - if contexts and context_list: - context_list = context_list + contexts - - new_treatment = [float(experiment.arms[chosen_arm_index].is_treatment_arm)] - treatments_list = ( - new_treatment if treatments is None else new_treatment + treatments + rewards_list, context_list, treatments_list = await format_rewards_for_arm_update( + experiment=experiment, + chosen_arm_id=draw.arm_id, + reward=reward, + asession=asession, ) # Update the arm with the given reward try: - # Get experiment type for observation type - await update_arm_based_on_outcome( experiment=experiment, draw=draw, diff --git a/backend/jobs/auto_fail.py b/backend/jobs/auto_fail.py index 3665831..2323a56 100644 --- a/backend/jobs/auto_fail.py +++ b/backend/jobs/auto_fail.py @@ -15,21 +15,16 @@ from sqlalchemy import select from 
sqlalchemy.ext.asyncio import AsyncSession -from app.bayes_ab.models import BayesianABDB, BayesianABDrawDB -from app.bayes_ab.observation import ( - update_based_on_outcome as bayes_ab_update_based_on_outcome, -) -from app.contextual_mab.models import ContextualBanditDB, ContextualDrawDB -from app.contextual_mab.observation import ( - update_based_on_outcome as cmab_update_based_on_outcome, -) from app.database import get_async_session -from app.mab.models import MABDrawDB, MultiArmedBanditDB -from app.mab.observation import update_based_on_outcome as mab_update_based_on_outcome -from app.schemas import ObservationType +from app.experiments.dependencies import ( + format_rewards_for_arm_update, + update_arm_based_on_outcome, +) +from app.experiments.models import DrawDB, ExperimentDB +from app.experiments.schemas import ObservationType -async def auto_fail_mab(asession: AsyncSession) -> int: +async def auto_fail_experiment(asession: AsyncSession) -> int: """ Auto fail experiments draws that have not been updated in a certain amount of time. 
@@ -43,9 +38,7 @@ async def auto_fail_mab(asession: AsyncSession) -> int: now = datetime.now(tz=timezone.utc) # Fetch all required experiments data in one query - experiment_query = select(MultiArmedBanditDB).where( - MultiArmedBanditDB.auto_fail.is_(True) - ) + experiment_query = select(ExperimentDB).where(ExperimentDB.auto_fail.is_(True)) experiments_result = (await asession.execute(experiment_query)).unique() experiments = experiments_result.scalars().all() for experiment in experiments: @@ -58,15 +51,15 @@ async def auto_fail_mab(asession: AsyncSession) -> int: cutoff_datetime = now - timedelta(hours=hours_threshold) draws_query = ( - select(MABDrawDB) + select(DrawDB) .join( - MultiArmedBanditDB, - MABDrawDB.experiment_id == MultiArmedBanditDB.experiment_id, + ExperimentDB, + DrawDB.experiment_id == ExperimentDB.experiment_id, ) .where( - MABDrawDB.experiment_id == experiment.experiment_id, - MABDrawDB.observation_type.is_(None), - MABDrawDB.draw_datetime_utc <= cutoff_datetime, + DrawDB.experiment_id == experiment.experiment_id, + DrawDB.observation_type.is_(None), + DrawDB.draw_datetime_utc <= cutoff_datetime, ) .limit(100) ) # Process in smaller batches @@ -83,146 +76,17 @@ async def auto_fail_mab(asession: AsyncSession) -> int: for draw in draws_batch: draw.observation_type = ObservationType.AUTO - await mab_update_based_on_outcome( - experiment, - draw, - 0.0, - asession, - ObservationType.AUTO, - ) - - total_failed += 1 - - await asession.commit() - offset += len(draws_batch) - - return total_failed - - -async def auto_fail_bayes_ab(asession: AsyncSession) -> int: - """ - Auto fail experiments draws that have not been updated in a certain amount of time. 
- - """ - total_failed = 0 - now = datetime.now(tz=timezone.utc) - - # Fetch all required experiments data in one query - experiment_query = select(BayesianABDB).where(BayesianABDB.auto_fail.is_(True)) - experiments_result = (await asession.execute(experiment_query)).unique() - experiments = experiments_result.scalars().all() - for experiment in experiments: - hours_threshold = ( - experiment.auto_fail_value * 24 - if experiment.auto_fail_unit == "days" - else experiment.auto_fail_value - ) - - cutoff_datetime = now - timedelta(hours=hours_threshold) - - draws_query = ( - select(BayesianABDrawDB) - .join( - BayesianABDB, - BayesianABDrawDB.experiment_id == BayesianABDB.experiment_id, - ) - .where( - BayesianABDrawDB.experiment_id == experiment.experiment_id, - BayesianABDrawDB.observation_type.is_(None), - BayesianABDrawDB.draw_datetime_utc <= cutoff_datetime, - ) - .limit(100) - ) # Process in smaller batches - - # Paginate through results if there are many draws to avoid memory issues - offset = 0 - while True: - batch_query = draws_query.offset(offset) - draws_result = (await asession.execute(batch_query)).unique() - draws_batch = draws_result.scalars().all() - if not draws_batch: - break - - for draw in draws_batch: - draw.observation_type = ObservationType.AUTO - - await bayes_ab_update_based_on_outcome( - experiment, - draw, - 0.0, - asession, - ObservationType.AUTO, + rewards_list, context_list, treatments_list = ( + await format_rewards_for_arm_update( + experiment, draw.arm_id, 0.0, asession + ) ) - - total_failed += 1 - - await asession.commit() - offset += len(draws_batch) - - return total_failed - - -async def auto_fail_cmab(asession: AsyncSession) -> int: - """ - Auto fail experiments draws that have not been updated in a certain amount of time. 
- - Args: - asession: SQLAlchemy async session - - Returns: - int: Number of draws automatically failed - """ - total_failed = 0 - now = datetime.now(tz=timezone.utc) - - # Fetch all required experiments data in one query - experiment_query = select(ContextualBanditDB).where( - ContextualBanditDB.auto_fail.is_(True) - ) - experiments_result = (await asession.execute(experiment_query)).unique() - experiments = experiments_result.scalars().all() - for experiment in experiments: - hours_threshold = ( - experiment.auto_fail_value * 24 - if experiment.auto_fail_unit == "days" - else experiment.auto_fail_value - ) - - cutoff_datetime = now - timedelta(hours=hours_threshold) - - draws_query = ( - select(ContextualDrawDB) - .join( - ContextualBanditDB, - ContextualDrawDB.experiment_id == ContextualBanditDB.experiment_id, - ) - .where( - ContextualDrawDB.experiment_id == experiment.experiment_id, - ContextualDrawDB.observation_type.is_(None), - ContextualDrawDB.draw_datetime_utc <= cutoff_datetime, - ) - .limit(100) - ) # Process in smaller batches - - # Paginate through results if there are many draws to avoid memory issues - offset = 0 - while True: - batch_query = draws_query.offset(offset) - draws_result = (await asession.execute(batch_query)).unique() - draws_batch = draws_result.scalars().all() - - if not draws_batch: - break - - for draw in draws_batch: - draw.observation_type = ObservationType.AUTO - - await cmab_update_based_on_outcome( + await update_arm_based_on_outcome( experiment, draw, - 0.0, - asession, - ObservationType.AUTO, + rewards_list, + context_list, + treatments_list, ) total_failed += 1 @@ -238,12 +102,8 @@ async def main() -> None: Main function to process notifications """ async for asession in get_async_session(): - failed_count = await auto_fail_mab(asession) - print(f"Auto-failed MABs: {failed_count} draws") - failed_count = await auto_fail_cmab(asession) - print(f"Auto-failed CMABs: {failed_count} draws") - failed_count = await 
auto_fail_bayes_ab(asession) - print(f"Auto-failed Bayes ABs: {failed_count} draws") + failed_count = await auto_fail_experiment(asession) + print(f"Auto-failed experiments: {failed_count} draws") break From 36a0db56415ecb721f95f378e83a86575d81b0f5 Mon Sep 17 00:00:00 2001 From: poornimaramesh Date: Tue, 3 Jun 2025 18:05:59 +0300 Subject: [PATCH 50/74] delete old routers and migrations --- backend/app/bayes_ab/__init__.py | 1 - backend/app/bayes_ab/models.py | 433 --------------- backend/app/bayes_ab/observation.py | 75 --- backend/app/bayes_ab/routers.py | 524 ------------------ backend/app/bayes_ab/sampling_utils.py | 126 ----- backend/app/bayes_ab/schemas.py | 145 ----- backend/app/contextual_mab/models.py | 483 ---------------- backend/app/contextual_mab/observation.py | 126 ----- backend/app/contextual_mab/routers.py | 395 ------------- backend/app/contextual_mab/schemas.py | 268 --------- backend/app/mab/__init__.py | 1 - backend/app/mab/models.py | 419 -------------- backend/app/mab/observation.py | 94 ---- backend/app/mab/routers.py | 357 ------------ backend/app/mab/sampling_utils.py | 138 ----- backend/app/mab/schemas.py | 262 --------- ...57dd34c4cd4_debugging_nullable_variable.py | 34 -- ...8f_update_models_for_treatment_arm_and_.py | 64 --- ...e1aa8ae_update_tables_with_workspace_id.py | 131 ----- ...w_start.py => 6101ba814d91_fresh_start.py} | 145 +---- .../versions/9f7482ba882f_workspace_model.py | 123 ---- .../ecddd830b464_remove_user_api_key.py | 70 --- 22 files changed, 12 insertions(+), 4402 deletions(-) delete mode 100644 backend/app/bayes_ab/__init__.py delete mode 100644 backend/app/bayes_ab/models.py delete mode 100644 backend/app/bayes_ab/observation.py delete mode 100644 backend/app/bayes_ab/routers.py delete mode 100644 backend/app/bayes_ab/sampling_utils.py delete mode 100644 backend/app/bayes_ab/schemas.py delete mode 100644 backend/app/contextual_mab/models.py delete mode 100644 backend/app/contextual_mab/observation.py delete mode 
100644 backend/app/contextual_mab/routers.py delete mode 100644 backend/app/contextual_mab/schemas.py delete mode 100644 backend/app/mab/__init__.py delete mode 100644 backend/app/mab/models.py delete mode 100644 backend/app/mab/observation.py delete mode 100644 backend/app/mab/routers.py delete mode 100644 backend/app/mab/sampling_utils.py delete mode 100644 backend/app/mab/schemas.py delete mode 100644 backend/migrations/versions/157dd34c4cd4_debugging_nullable_variable.py delete mode 100644 backend/migrations/versions/4c06937ee88f_update_models_for_treatment_arm_and_.py delete mode 100644 backend/migrations/versions/57173e1aa8ae_update_tables_with_workspace_id.py rename backend/migrations/versions/{2d3946caceff_new_start.py => 6101ba814d91_fresh_start.py} (76%) delete mode 100644 backend/migrations/versions/9f7482ba882f_workspace_model.py delete mode 100644 backend/migrations/versions/ecddd830b464_remove_user_api_key.py diff --git a/backend/app/bayes_ab/__init__.py b/backend/app/bayes_ab/__init__.py deleted file mode 100644 index fa07d07..0000000 --- a/backend/app/bayes_ab/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .routers import router # noqa: F401 diff --git a/backend/app/bayes_ab/models.py b/backend/app/bayes_ab/models.py deleted file mode 100644 index 8caee04..0000000 --- a/backend/app/bayes_ab/models.py +++ /dev/null @@ -1,433 +0,0 @@ -from datetime import datetime, timezone -from typing import Sequence - -from sqlalchemy import ( - Boolean, - Float, - ForeignKey, - and_, - delete, - select, -) -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import Mapped, mapped_column, relationship - -from ..models import ( - ArmBaseDB, - DrawsBaseDB, - ExperimentBaseDB, - NotificationsDB, -) -from ..schemas import ObservationType -from .schemas import BayesianAB - - -class BayesianABDB(ExperimentBaseDB): - """ - ORM for managing experiments. 
- """ - - __tablename__ = "bayes_ab_experiments" - - experiment_id: Mapped[int] = mapped_column( - ForeignKey("experiments_base.experiment_id", ondelete="CASCADE"), - primary_key=True, - nullable=False, - ) - - arms: Mapped[list["BayesianABArmDB"]] = relationship( - "BayesianABArmDB", back_populates="experiment", lazy="selectin" - ) - - draws: Mapped[list["BayesianABDrawDB"]] = relationship( - "BayesianABDrawDB", back_populates="experiment", lazy="joined" - ) - - __mapper_args__ = {"polymorphic_identity": "bayes_ab_experiments"} - - def to_dict(self) -> dict: - """ - Convert the ORM object to a dictionary. - """ - return { - "experiment_id": self.experiment_id, - "user_id": self.user_id, - "workspace_id": self.workspace_id, - "name": self.name, - "description": self.description, - "sticky_assignment": self.sticky_assignment, - "auto_fail": self.auto_fail, - "auto_fail_value": self.auto_fail_value, - "auto_fail_unit": self.auto_fail_unit, - "created_datetime_utc": self.created_datetime_utc, - "is_active": self.is_active, - "n_trials": self.n_trials, - "arms": [arm.to_dict() for arm in self.arms], - "prior_type": self.prior_type, - "reward_type": self.reward_type, - } - - -class BayesianABArmDB(ArmBaseDB): - """ - ORM for managing arms. 
- """ - - __tablename__ = "bayes_ab_arms" - - arm_id: Mapped[int] = mapped_column( - ForeignKey("arms_base.arm_id", ondelete="CASCADE"), - primary_key=True, - nullable=False, - ) - - # prior variables for AB arms - mu_init: Mapped[float] = mapped_column(Float, nullable=False) - sigma_init: Mapped[float] = mapped_column(Float, nullable=False) - mu: Mapped[float] = mapped_column(Float, nullable=False) - sigma: Mapped[float] = mapped_column(Float, nullable=False) - is_treatment_arm: Mapped[bool] = mapped_column( - Boolean, nullable=False, default=False - ) - - experiment: Mapped[BayesianABDB] = relationship( - "BayesianABDB", back_populates="arms", lazy="joined" - ) - draws: Mapped[list["BayesianABDrawDB"]] = relationship( - "BayesianABDrawDB", back_populates="arm", lazy="joined" - ) - - __mapper_args__ = {"polymorphic_identity": "bayes_ab_arms"} - - def to_dict(self) -> dict: - """ - Convert the ORM object to a dictionary. - """ - return { - "arm_id": self.arm_id, - "name": self.name, - "description": self.description, - "mu_init": self.mu_init, - "sigma_init": self.sigma_init, - "mu": self.mu, - "sigma": self.sigma, - "is_treatment_arm": self.is_treatment_arm, - "draws": [draw.to_dict() for draw in self.draws], - } - - -class BayesianABDrawDB(DrawsBaseDB): - """ - ORM for managing draws of AB experiment. - """ - - __tablename__ = "bayes_ab_draws" - - draw_id: Mapped[str] = mapped_column( # Changed from int to str - ForeignKey("draws_base.draw_id", ondelete="CASCADE"), - primary_key=True, - nullable=False, - ) - - arm: Mapped[BayesianABArmDB] = relationship( - "BayesianABArmDB", back_populates="draws", lazy="joined" - ) - experiment: Mapped[BayesianABDB] = relationship( - "BayesianABDB", back_populates="draws", lazy="joined" - ) - - __mapper_args__ = {"polymorphic_identity": "bayes_ab_draws"} - - def to_dict(self) -> dict: - """ - Convert the ORM object to a dictionary. 
- """ - return { - "draw_id": self.draw_id, - "client_id": self.client_id, - "draw_datetime_utc": self.draw_datetime_utc, - "arm_id": self.arm_id, - "experiment_id": self.experiment_id, - "user_id": self.user_id, - "reward": self.reward, - "observation_type": self.observation_type, - "observed_datetime_utc": self.observed_datetime_utc, - } - - -async def save_bayes_ab_to_db( - ab_experiment: BayesianAB, - user_id: int, - workspace_id: int, - asession: AsyncSession, -) -> BayesianABDB: - """ - Save the A/B experiment to the database. - """ - arms = [ - BayesianABArmDB( - name=arm.name, - description=arm.description, - mu_init=arm.mu_init, - sigma_init=arm.sigma_init, - n_outcomes=arm.n_outcomes, - is_treatment_arm=arm.is_treatment_arm, - mu=arm.mu_init, - sigma=arm.sigma_init, - user_id=user_id, - ) - for arm in ab_experiment.arms - ] - - bayes_ab_db = BayesianABDB( - name=ab_experiment.name, - description=ab_experiment.description, - user_id=user_id, - workspace_id=workspace_id, - is_active=ab_experiment.is_active, - created_datetime_utc=datetime.now(timezone.utc), - n_trials=0, - arms=arms, - sticky_assignment=ab_experiment.sticky_assignment, - auto_fail=ab_experiment.auto_fail, - auto_fail_value=ab_experiment.auto_fail_value, - auto_fail_unit=ab_experiment.auto_fail_unit, - prior_type=ab_experiment.prior_type.value, - reward_type=ab_experiment.reward_type.value, - ) - - asession.add(bayes_ab_db) - await asession.commit() - await asession.refresh(bayes_ab_db) - - return bayes_ab_db - - -async def get_all_bayes_ab_experiments( - workspace_id: int, - asession: AsyncSession, -) -> Sequence[BayesianABDB]: - """ - Get all the A/B experiments from the database for a specific workspace. 
- """ - stmt = ( - select(BayesianABDB) - .where(BayesianABDB.workspace_id == workspace_id) - .order_by(BayesianABDB.experiment_id) - ) - result = await asession.execute(stmt) - return result.unique().scalars().all() - - -async def get_bayes_ab_experiment_by_id( - experiment_id: int, - workspace_id: int, - asession: AsyncSession, -) -> BayesianABDB | None: - """ - Get the A/B experiment by id from a specific workspace. - """ - conditions = [ - BayesianABDB.workspace_id == workspace_id, - BayesianABDB.experiment_id == experiment_id, - ] - - stmt = select(BayesianABDB).where(and_(*conditions)) - result = await asession.execute(stmt) - return result.unique().scalar_one_or_none() - - -async def delete_bayes_ab_experiment_by_id( - experiment_id: int, - workspace_id: int, - asession: AsyncSession, -) -> None: - """ - Delete the A/B experiment by id from a specific workspace. - """ - stmt = delete(BayesianABDB).where( - and_( - BayesianABDB.workspace_id == workspace_id, - BayesianABDB.experiment_id == experiment_id, - BayesianABDB.experiment_id == ExperimentBaseDB.experiment_id, - ) - ) - await asession.execute(stmt) - - stmt = delete(NotificationsDB).where( - NotificationsDB.experiment_id == experiment_id, - ) - await asession.execute(stmt) - - stmt = delete(BayesianABDrawDB).where( - and_( - BayesianABDrawDB.draw_id == DrawsBaseDB.draw_id, - BayesianABDrawDB.experiment_id == experiment_id, - ) - ) - await asession.execute(stmt) - - stmt = delete(BayesianABArmDB).where( - and_( - BayesianABArmDB.arm_id == ArmBaseDB.arm_id, - BayesianABArmDB.experiment_id == experiment_id, - ) - ) - await asession.execute(stmt) - - await asession.commit() - return None - - -async def save_bayes_ab_observation_to_db( - draw: BayesianABDrawDB, - reward: float, - asession: AsyncSession, - observation_type: ObservationType = ObservationType.AUTO, -) -> BayesianABDrawDB: - """ - Save the A/B observation to the database. 
- """ - draw.reward = reward - draw.observed_datetime_utc = datetime.now(timezone.utc) - draw.observation_type = observation_type - - await asession.commit() - await asession.refresh(draw) - - return draw - - -async def save_bayes_ab_draw_to_db( - experiment_id: int, - arm_id: int, - draw_id: str, - client_id: str | None, - user_id: int | None, - asession: AsyncSession, - workspace_id: int | None, -) -> BayesianABDrawDB: - """ - Save a draw to the database - """ - # If user_id is not provided but needed, get it from the experiment - if user_id is None and workspace_id is not None: - experiment = await get_bayes_ab_experiment_by_id( - experiment_id=experiment_id, - workspace_id=workspace_id, - asession=asession, - ) - if experiment: - user_id = experiment.user_id - else: - raise ValueError(f"Experiment with id {experiment_id} not found") - - if user_id is None: - raise ValueError("User ID must be provided or derivable from experiment") - - draw_datetime_utc: datetime = datetime.now(timezone.utc) - - draw = BayesianABDrawDB( - draw_id=draw_id, - client_id=client_id, - experiment_id=experiment_id, - user_id=user_id, - arm_id=arm_id, - draw_datetime_utc=draw_datetime_utc, - ) - - asession.add(draw) - await asession.commit() - await asession.refresh(draw) - - return draw - - -async def get_bayes_ab_obs_by_experiment_arm_id( - experiment_id: int, - arm_id: int, - asession: AsyncSession, -) -> Sequence[BayesianABDrawDB]: - """ - Get the observations of a specific arm in an A/B experiment. 
- """ - stmt = ( - select(BayesianABDrawDB) - .where( - and_( - BayesianABDrawDB.experiment_id == experiment_id, - BayesianABDrawDB.arm_id == arm_id, - BayesianABDrawDB.reward.is_not(None), - ) - ) - .order_by(BayesianABDrawDB.observed_datetime_utc) - ) - - result = await asession.execute(stmt) - return result.unique().scalars().all() - - -async def get_bayes_ab_obs_by_experiment_id( - experiment_id: int, - workspace_id: int, - asession: AsyncSession, -) -> Sequence[BayesianABDrawDB]: - """ - Get the observations of the A/B experiment. - Verified to belong to the specified workspace. - """ - # First, verify experiment belongs to the workspace - experiment = await get_bayes_ab_experiment_by_id( - experiment_id=experiment_id, - workspace_id=workspace_id, - asession=asession, - ) - - if experiment is None: - # Return empty list if experiment doesn't exist or doesn't belong to workspace - return [] - - # Get observations for this experiment - stmt = ( - select(BayesianABDrawDB) - .where( - and_( - BayesianABDrawDB.experiment_id == experiment_id, - BayesianABDrawDB.reward.is_not(None), - ) - ) - .order_by(BayesianABDrawDB.observed_datetime_utc) - ) - - result = await asession.execute(stmt) - return result.unique().scalars().all() - - -async def get_bayes_ab_draw_by_id( - draw_id: str, asession: AsyncSession -) -> BayesianABDrawDB | None: - """ - Get a draw by its ID - """ - statement = select(BayesianABDrawDB).where(BayesianABDrawDB.draw_id == draw_id) - result = await asession.execute(statement) - - return result.unique().scalar_one_or_none() - - -async def get_bayes_ab_draw_by_client_id( - client_id: str, experiment_id: int, asession: AsyncSession -) -> BayesianABDrawDB | None: - """ - Get a draw by its client ID for a specific experiment. 
- """ - statement = select(BayesianABDrawDB).where( - and_( - BayesianABDrawDB.client_id == client_id, - BayesianABDrawDB.client_id.is_not(None), - BayesianABDrawDB.experiment_id == experiment_id, - ) - ) - result = await asession.execute(statement) - - return result.unique().scalars().first() diff --git a/backend/app/bayes_ab/observation.py b/backend/app/bayes_ab/observation.py deleted file mode 100644 index 212dc91..0000000 --- a/backend/app/bayes_ab/observation.py +++ /dev/null @@ -1,75 +0,0 @@ -from datetime import datetime, timezone - -from fastapi.exceptions import HTTPException -from sqlalchemy.ext.asyncio import AsyncSession - -from ..schemas import ObservationType, Outcome, RewardLikelihood -from .models import ( - BayesianABArmDB, - BayesianABDB, - BayesianABDrawDB, - save_bayes_ab_observation_to_db, -) -from .schemas import ( - BayesABArmResponse, - BayesianABSample, -) - - -async def update_based_on_outcome( - experiment: BayesianABDB, - draw: BayesianABDrawDB, - outcome: float, - asession: AsyncSession, - observation: ObservationType, -) -> BayesABArmResponse: - """ - Update the arm parameters based on the outcome. - - This is a helper function to allow `auto_fail` job to call - it as well. - """ - update_experiment_metadata(experiment) - - arm = get_arm_from_experiment(experiment, draw.arm_id) - arm.n_outcomes += 1 - - experiment_data = BayesianABSample.model_validate(experiment) - if experiment_data.reward_type == RewardLikelihood.BERNOULLI: - Outcome(outcome) # Check if reward is 0 or 1 - - await save_updated_data(arm, draw, outcome, asession) - - return BayesABArmResponse.model_validate(arm) - - -def update_experiment_metadata(experiment: BayesianABDB) -> None: - """ - Update the experiment metadata with new information. 
- """ - experiment.n_trials += 1 - experiment.last_trial_datetime_utc = datetime.now(tz=timezone.utc) - - -def get_arm_from_experiment(experiment: BayesianABDB, arm_id: int) -> BayesianABArmDB: - """ - Get and validate the arm from the experiment. - """ - arms = [a for a in experiment.arms if a.arm_id == arm_id] - if not arms: - raise HTTPException(status_code=404, detail=f"Arm with id {arm_id} not found") - return arms[0] - - -async def save_updated_data( - arm: BayesianABArmDB, - draw: BayesianABDrawDB, - outcome: float, - asession: AsyncSession, -) -> None: - """ - Save the updated data to the database. - """ - asession.add(arm) - await asession.commit() - await save_bayes_ab_observation_to_db(draw, outcome, asession) diff --git a/backend/app/bayes_ab/routers.py b/backend/app/bayes_ab/routers.py deleted file mode 100644 index c1041a2..0000000 --- a/backend/app/bayes_ab/routers.py +++ /dev/null @@ -1,524 +0,0 @@ -from typing import Annotated, Optional -from uuid import uuid4 - -from fastapi import APIRouter, Depends -from fastapi.exceptions import HTTPException -from sqlalchemy.ext.asyncio import AsyncSession - -from ..auth.dependencies import ( - authenticate_workspace_key, - get_verified_user, - require_admin_role, -) -from ..database import get_async_session -from ..models import get_notifications_from_db, save_notifications_to_db -from ..schemas import NotificationsResponse, ObservationType -from ..users.models import UserDB -from ..workspaces.models import ( - WorkspaceDB, - get_user_default_workspace, -) -from .models import ( - BayesianABDB, - BayesianABDrawDB, - delete_bayes_ab_experiment_by_id, - get_all_bayes_ab_experiments, - get_bayes_ab_draw_by_client_id, - get_bayes_ab_draw_by_id, - get_bayes_ab_experiment_by_id, - get_bayes_ab_obs_by_experiment_id, - save_bayes_ab_draw_to_db, - save_bayes_ab_to_db, -) -from .observation import update_based_on_outcome -from .sampling_utils import choose_arm, update_arm_params -from .schemas import ( - 
BayesABArmResponse, - BayesianAB, - BayesianABDrawResponse, - BayesianABObservationResponse, - BayesianABResponse, - BayesianABSample, -) - -router = APIRouter(prefix="/bayes_ab", tags=["Bayesian A/B Testing"]) - - -@router.post("/", response_model=BayesianABResponse) -async def create_ab_experiment( - experiment: BayesianAB, - user_db: Annotated[UserDB, Depends(require_admin_role)], - asession: AsyncSession = Depends(get_async_session), -) -> BayesianABResponse: - """ - Create a new experiment in the user's current workspace. - """ - workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) - - if workspace_db is None: - raise HTTPException( - status_code=404, - detail="Workspace not found for the user.", - ) - - bayes_ab = await save_bayes_ab_to_db( - experiment, user_db.user_id, workspace_db.workspace_id, asession - ) - - notifications = await save_notifications_to_db( - experiment_id=bayes_ab.experiment_id, - user_id=user_db.user_id, - notifications=experiment.notifications, - asession=asession, - ) - - bayes_ab_dict = bayes_ab.to_dict() - bayes_ab_dict["notifications"] = [n.to_dict() for n in notifications] - - return BayesianABResponse.model_validate(bayes_ab_dict) - - -@router.get("/", response_model=list[BayesianABResponse]) -async def get_bayes_abs( - user_db: Annotated[UserDB, Depends(get_verified_user)], - asession: AsyncSession = Depends(get_async_session), -) -> list[BayesianABResponse]: - """ - Get details of all experiments in the user's current workspace. 
- """ - workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) - - if workspace_db is None: - raise HTTPException( - status_code=404, - detail="Workspace not found for the user.", - ) - - experiments = await get_all_bayes_ab_experiments( - workspace_db.workspace_id, asession - ) - - all_experiments = [] - for exp in experiments: - exp_dict = exp.to_dict() - exp_dict["notifications"] = [ - n.to_dict() - for n in await get_notifications_from_db( - exp.experiment_id, exp.user_id, asession - ) - ] - all_experiments.append( - BayesianABResponse.model_validate( - { - **exp_dict, - "notifications": [ - NotificationsResponse(**n) for n in exp_dict["notifications"] - ], - } - ) - ) - return all_experiments - - -@router.get("/{experiment_id}", response_model=BayesianABResponse) -async def get_bayes_ab( - experiment_id: int, - user_db: Annotated[UserDB, Depends(get_verified_user)], - asession: AsyncSession = Depends(get_async_session), -) -> BayesianABResponse: - """ - Get details of experiment with the provided `experiment_id`. 
- """ - workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) - - if workspace_db is None: - raise HTTPException( - status_code=404, - detail="Workspace not found for the user.", - ) - - experiment = await get_bayes_ab_experiment_by_id( - experiment_id, workspace_db.workspace_id, asession - ) - - if experiment is None: - raise HTTPException( - status_code=404, detail=f"Experiment with id {experiment_id} not found" - ) - - experiment_dict = experiment.to_dict() - experiment_dict["notifications"] = [ - n.to_dict() - for n in await get_notifications_from_db( - experiment.experiment_id, experiment.user_id, asession - ) - ] - - return BayesianABResponse.model_validate(experiment_dict) - - -@router.delete("/{experiment_id}", response_model=dict) -async def delete_bayes_ab( - experiment_id: int, - user_db: Annotated[UserDB, Depends(require_admin_role)], - asession: AsyncSession = Depends(get_async_session), -) -> dict: - """ - Delete the experiment with the provided `experiment_id`. 
- """ - try: - workspace_db = await get_user_default_workspace( - asession=asession, user_db=user_db - ) - - if workspace_db is None: - raise HTTPException( - status_code=404, - detail="Workspace not found for the user.", - ) - - experiment = await get_bayes_ab_experiment_by_id( - experiment_id, workspace_db.workspace_id, asession - ) - if experiment is None: - raise HTTPException( - status_code=404, detail=f"Experiment with id {experiment_id} not found" - ) - - await delete_bayes_ab_experiment_by_id( - experiment_id, workspace_db.workspace_id, asession - ) - - return {"message": f"Experiment with id {experiment_id} deleted successfully."} - except Exception as e: - raise HTTPException(status_code=500, detail=f"Error: {e}") from e - - -@router.get("/{experiment_id}/draw", response_model=BayesianABDrawResponse) -async def draw_arm( - experiment_id: int, - draw_id: Optional[str] = None, - client_id: Optional[str] = None, - workspace_db: WorkspaceDB = Depends(authenticate_workspace_key), - asession: AsyncSession = Depends(get_async_session), -) -> BayesianABDrawResponse: - """ - Get which arm to pull next for provided experiment. 
- """ - # Get workspace from user context - workspace_id = workspace_db.workspace_id - - experiment = await get_bayes_ab_experiment_by_id( - experiment_id, workspace_id, asession - ) - - if experiment is None: - raise HTTPException( - status_code=404, detail=f"Experiment with id {experiment_id} not found" - ) - - if experiment.sticky_assignment and not client_id: - raise HTTPException( - status_code=400, - detail="Client ID is required for sticky assignment.", - ) - - experiment_data = BayesianABSample.model_validate(experiment) - chosen_arm = choose_arm(experiment=experiment_data) - chosen_arm_id = experiment.arms[chosen_arm].arm_id - if experiment.sticky_assignment and client_id: - # Check if the client_id is already assigned to an arm - previous_draw = await get_bayes_ab_draw_by_client_id( - client_id=client_id, - experiment_id=experiment_id, - asession=asession, - ) - if previous_draw: - chosen_arm_id = previous_draw.arm_id - - # Check for existing draws - if draw_id is None: - draw_id = str(uuid4()) - - existing_draw = await get_bayes_ab_draw_by_id(draw_id=draw_id, asession=asession) - if existing_draw: - raise HTTPException( - status_code=400, - detail=f"Draw with id {draw_id} already exists for \ - experiment {experiment_id}", - ) - - try: - await save_bayes_ab_draw_to_db( - experiment_id=experiment.experiment_id, - arm_id=chosen_arm_id, - draw_id=draw_id, - client_id=client_id, - user_id=None, - asession=asession, - workspace_id=workspace_id, - ) - except Exception as e: - raise HTTPException( - status_code=500, - detail=f"Error saving draw to database: {e}", - ) from e - - return BayesianABDrawResponse.model_validate( - { - "draw_id": draw_id, - "client_id": client_id, - "arm": BayesABArmResponse.model_validate( - [arm for arm in experiment.arms if arm.arm_id == chosen_arm_id][0], - ), - } - ) - - -@router.put("/{experiment_id}/{draw_id}/{outcome}", response_model=BayesABArmResponse) -async def save_observation_for_arm( - experiment_id: int, - draw_id: 
str, - outcome: float, - workspace_db: WorkspaceDB = Depends(authenticate_workspace_key), - asession: AsyncSession = Depends(get_async_session), -) -> BayesABArmResponse: - """ - Update the arm with the provided `arm_id` for the given - `experiment_id` based on the `outcome`. - """ - # Get workspace from user context - workspace_id = workspace_db.workspace_id - - # Get and validate experiment - experiment, draw = await validate_experiment_and_draw( - experiment_id=experiment_id, - draw_id=draw_id, - workspace_id=workspace_id, - asession=asession, - ) - - return await update_based_on_outcome( - experiment=experiment, - draw=draw, - outcome=outcome, - asession=asession, - observation=ObservationType.USER, - ) - - -@router.get( - "/{experiment_id}/outcomes", - response_model=list[BayesianABObservationResponse], -) -async def get_outcomes( - experiment_id: int, - workspace_db: WorkspaceDB = Depends(authenticate_workspace_key), - asession: AsyncSession = Depends(get_async_session), -) -> list[BayesianABObservationResponse]: - """ - Get the outcomes for the experiment. - """ - # Get workspace from user context - workspace_id = workspace_db.workspace_id - - experiment = await get_bayes_ab_experiment_by_id( - experiment_id, workspace_id, asession - ) - if not experiment: - raise HTTPException( - status_code=404, detail=f"Experiment with id {experiment_id} not found" - ) - - rewards = await get_bayes_ab_obs_by_experiment_id( - experiment_id=experiment.experiment_id, - workspace_id=workspace_id, - asession=asession, - ) - - return [BayesianABObservationResponse.model_validate(reward) for reward in rewards] - - -@router.get( - "/{experiment_id}/arms", - response_model=list[BayesABArmResponse], -) -async def update_arms( - experiment_id: int, - workspace_db: WorkspaceDB = Depends(authenticate_workspace_key), - asession: AsyncSession = Depends(get_async_session), -) -> list[BayesABArmResponse]: - """ - Get the outcomes for the experiment. 
- """ - # Get workspace from user context - workspace_id = workspace_db.workspace_id - - # Check experiment params - experiment = await get_bayes_ab_experiment_by_id( - experiment_id, workspace_id, asession - ) - if not experiment: - raise HTTPException( - status_code=404, detail=f"Experiment with id {experiment_id} not found" - ) - - # Prepare data for arms update - ( - rewards, - treatments, - treatment_mu, - treatment_sigma, - control_mu, - control_sigma, - ) = await prepare_data_for_arms_update( - experiment=experiment, - workspace_id=workspace_id, - asession=asession, - ) - - # Make updates - arms_data = await make_updates_to_arms( - experiment=experiment, - treatment_mu=treatment_mu, - treatment_sigma=treatment_sigma, - control_mu=control_mu, - control_sigma=control_sigma, - rewards=rewards, - treatments=treatments, - asession=asession, - ) - - return arms_data - - -# ---- Helper functions ---- - - -async def validate_experiment_and_draw( - experiment_id: int, - draw_id: str, - workspace_id: int, - asession: AsyncSession, -) -> tuple[BayesianABDB, BayesianABDrawDB]: - """Validate the experiment and draw""" - experiment = await get_bayes_ab_experiment_by_id( - experiment_id, workspace_id, asession - ) - if experiment is None: - raise HTTPException( - status_code=404, detail=f"Experiment with id {experiment_id} not found" - ) - - draw = await get_bayes_ab_draw_by_id(draw_id=draw_id, asession=asession) - if draw is None: - raise HTTPException(status_code=404, detail=f"Draw with id {draw_id} not found") - - if draw.experiment_id != experiment_id: - raise HTTPException( - status_code=400, - detail=( - f"Draw with id {draw_id} does not belong " - f"to experiment with id {experiment_id}", - ), - ) - - if draw.reward is not None: - raise HTTPException( - status_code=400, - detail=f"Draw with id {draw_id} already has an outcome.", - ) - - return experiment, draw - - -async def prepare_data_for_arms_update( - experiment: BayesianABDB, - workspace_id: int, - asession: 
AsyncSession, -) -> tuple[list[float], list[float], float, float, float, float]: - """ - Prepare the data for arm update. - """ - # Get observations - observations = await get_bayes_ab_obs_by_experiment_id( - experiment_id=experiment.experiment_id, - workspace_id=workspace_id, - asession=asession, - ) - - if not observations: - raise HTTPException( - status_code=404, - detail=f"No observations found for experiment {experiment.experiment_id}", - ) - - rewards = [obs.reward for obs in observations] - - # Get treatment and control arms - arms_dict = { - arm.arm_id: 1.0 if arm.is_treatment_arm else 0.0 for arm in experiment.arms - } - - # Get params - treatment_mu, treatment_sigma = [ - (arm.mu_init, arm.sigma_init) for arm in experiment.arms if arm.is_treatment_arm - ][0] - control_mu, control_sigma = [ - (arm.mu_init, arm.sigma_init) - for arm in experiment.arms - if not arm.is_treatment_arm - ][0] - - treatments = [arms_dict[obs.arm_id] for obs in observations] - - return ( - rewards, - treatments, - treatment_mu, - treatment_sigma, - control_mu, - control_sigma, - ) - - -async def make_updates_to_arms( - experiment: BayesianABDB, - treatment_mu: float, - treatment_sigma: float, - control_mu: float, - control_sigma: float, - rewards: list[float], - treatments: list[float], - asession: AsyncSession, -) -> list[BayesABArmResponse]: - """ - Make updates to the arms of the experiment. 
- """ - # Make updates - experiment_data = BayesianABSample.model_validate(experiment) - new_means, new_sigmas = update_arm_params( - experiment=experiment_data, - mus=[treatment_mu, control_mu], - sigmas=[treatment_sigma, control_sigma], - rewards=rewards, - treatments=treatments, - ) - - arms_data = [] - for arm in experiment.arms: - if arm.is_treatment_arm: - arm.mu = new_means[0] - arm.sigma = new_sigmas[0] - else: - arm.mu = new_means[1] - arm.sigma = new_sigmas[1] - - asession.add(arm) - arms_data.append(BayesABArmResponse.model_validate(arm)) - - asession.add(experiment) - - await asession.commit() - - return arms_data diff --git a/backend/app/bayes_ab/sampling_utils.py b/backend/app/bayes_ab/sampling_utils.py deleted file mode 100644 index 0416f64..0000000 --- a/backend/app/bayes_ab/sampling_utils.py +++ /dev/null @@ -1,126 +0,0 @@ -import numpy as np -from scipy.optimize import minimize - -from ..schemas import ArmPriors, ContextLinkFunctions, RewardLikelihood -from .schemas import BayesianABSample - - -def _update_arms( - mus: np.ndarray, - sigmas: np.ndarray, - rewards: np.ndarray, - treatments: np.ndarray, - link_function: ContextLinkFunctions, - reward_likelihood: RewardLikelihood, - prior_type: ArmPriors, -) -> tuple[list, list]: - """ - Get arm posteriors. - - Parameters - ---------- - mu : np.ndarray - The mean of the Normal distribution. - sigma : np.ndarray - The standard deviation of the Normal distribution. - rewards : np.ndarray - The rewards. - treatments : np.ndarray - The treatments (binary-valued). - link_function : ContextLinkFunctions - The link function for parameters to rewards. - reward_likelihood : RewardLikelihood - The likelihood function of the reward. - prior_type : ArmPriors - The prior type of the arm. - """ - - # TODO we explicitly assume that there is only 1 treatment arm - def objective(treatment_effect_arms_bias: np.ndarray) -> float: - """ - Objective function for arm to outcome. 
- - Parameters - ---------- - treatment_effect : float - The treatment effect. - """ - treatment, control, bias = treatment_effect_arms_bias - - # log prior - log_prior = prior_type( - np.array([treatment, control]), mu=mus, covariance=np.diag(sigmas) - ) - - # log likelihood - log_likelihood = reward_likelihood( - rewards, - link_function(treatment * treatments + control * (1 - treatments) + bias), - ) - return -(log_prior + log_likelihood) - - result = minimize(objective, x0=np.zeros(3), method="L-BFGS-B", hess="2-point") - new_treatment_mean, new_control_mean, _ = result.x - new_treatment_sigma, new_control_sigma, _ = np.sqrt( - np.diag(result.hess_inv.todense()) # type: ignore - ) - return [new_treatment_mean, new_control_mean], [ - new_treatment_sigma, - new_control_sigma, - ] - - -def choose_arm(experiment: BayesianABSample) -> int: - """ - Choose arm based on posterior - - Parameters - ---------- - experiment : BayesianABSample - The experiment data containing priors and rewards for each arm. - """ - index = np.random.choice(len(experiment.arms), size=1) - return int(index[0]) - - -def update_arm_params( - experiment: BayesianABSample, - mus: list[float], - sigmas: list[float], - rewards: list[float], - treatments: list[float], -) -> tuple[list, list]: - """ - Update the arm parameters based on the reward type. - - Parameters - ---------- - experiment : BayesianABSample - The experiment data containing arms, prior type and reward - type information. - mus : list[float] - The means of the arms. - sigmas : list[float] - The standard deviations of the arms. - rewards : list[float] - The rewards. - treatments : list[float] - Which arm was applied corresponding to the reward. 
- """ - link_function = None - if experiment.reward_type == RewardLikelihood.NORMAL: - link_function = ContextLinkFunctions.NONE - elif experiment.reward_type == RewardLikelihood.BERNOULLI: - link_function = ContextLinkFunctions.LOGISTIC - else: - raise ValueError("Invalid reward type") - - return _update_arms( - mus=np.array(mus), - sigmas=np.array(sigmas), - rewards=np.array(rewards), - treatments=np.array(treatments), - link_function=link_function, - reward_likelihood=experiment.reward_type, - prior_type=experiment.prior_type, - ) diff --git a/backend/app/bayes_ab/schemas.py b/backend/app/bayes_ab/schemas.py deleted file mode 100644 index ef55d07..0000000 --- a/backend/app/bayes_ab/schemas.py +++ /dev/null @@ -1,145 +0,0 @@ -from datetime import datetime -from typing import Optional, Self - -from pydantic import BaseModel, ConfigDict, Field, model_validator - -from ..mab.schemas import ( - MABObservationResponse, - MultiArmedBanditBase, -) -from ..schemas import Notifications, NotificationsResponse, allowed_combos_bayes_ab - - -class BayesABArm(BaseModel): - """ - Pydantic model for a arm of the experiment. - """ - - name: str = Field( - max_length=150, - examples=["Arm 1"], - ) - description: str = Field( - max_length=500, - examples=["This is a description of the arm."], - ) - - mu_init: float = Field( - default=0.0, description="Mean parameter for treatment effect prior" - ) - sigma_init: float = Field( - default=1.0, description="Std dev parameter for treatment effect prior" - ) - n_outcomes: Optional[int] = Field( - default=0, - description="Number of outcomes for the arm", - examples=[0, 10, 15], - ) - is_treatment_arm: bool = Field( - default=True, - description="Is the arm a treatment arm", - examples=[True, False], - ) - - @model_validator(mode="after") - def check_values(self) -> Self: - """ - Check if the values are unique and set new attributes. 
- """ - if self.sigma_init is not None and self.sigma_init <= 0: - raise ValueError("Std dev must be greater than 0.") - return self - - model_config = ConfigDict(from_attributes=True) - - -class BayesABArmResponse(BayesABArm): - """ - Pydantic model for a response for contextual arm creation - """ - - arm_id: int - mu: float - sigma: float - model_config = ConfigDict(from_attributes=True) - - -class BayesianAB(MultiArmedBanditBase): - """ - Pydantic model for an A/B experiment. - """ - - arms: list[BayesABArm] - notifications: Notifications - model_config = ConfigDict(from_attributes=True) - - @model_validator(mode="after") - def arms_exactly_two(self) -> Self: - """ - Validate that the experiment has exactly two arms. - """ - if len(self.arms) != 2: - raise ValueError("The experiment must have at exactly two arms.") - return self - - @model_validator(mode="after") - def check_prior_reward_type_combo(self) -> Self: - """ - Validate that the prior and reward type combination is allowed. - """ - - if (self.prior_type, self.reward_type) not in allowed_combos_bayes_ab: - raise ValueError("Prior and reward type combo not supported.") - return self - - @model_validator(mode="after") - def check_treatment_and_control_arm(self) -> Self: - """ - Validate that the experiment has at least one control arm. - """ - if sum(arm.is_treatment_arm for arm in self.arms) != 1: - raise ValueError("The experiment must have one treatment and control arm.") - return self - - -class BayesianABResponse(MultiArmedBanditBase): - """ - Pydantic model for a response for an A/B experiment. - """ - - experiment_id: int - workspace_id: int - arms: list[BayesABArmResponse] - notifications: list[NotificationsResponse] - created_datetime_utc: datetime - last_trial_datetime_utc: Optional[datetime] = None - n_trials: int - - model_config = ConfigDict(from_attributes=True) - - -class BayesianABSample(MultiArmedBanditBase): - """ - Pydantic model for a sample A/B experiment. 
- """ - - experiment_id: int - arms: list[BayesABArmResponse] - - -class BayesianABObservationResponse(MABObservationResponse): - """ - Pydantic model for an observation response in an A/B experiment. - """ - - pass - - -class BayesianABDrawResponse(BaseModel): - """ - Pydantic model for a draw response in an A/B experiment. - """ - - draw_id: str - client_id: str | None - arm: BayesABArmResponse diff --git a/backend/app/contextual_mab/models.py b/backend/app/contextual_mab/models.py deleted file mode 100644 index 60cf723..0000000 --- a/backend/app/contextual_mab/models.py +++ /dev/null @@ -1,483 +0,0 @@ -from datetime import datetime, timezone -from typing import Sequence - -import numpy as np -from sqlalchemy import ( - Float, - ForeignKey, - Integer, - String, - and_, - delete, - select, -) -from sqlalchemy.dialects.postgresql import ARRAY -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import Mapped, mapped_column, relationship - -from ..models import ( - ArmBaseDB, - Base, - DrawsBaseDB, - ExperimentBaseDB, - NotificationsDB, -) -from ..schemas import ObservationType -from .schemas import ContextualBandit - - -class ContextualBanditDB(ExperimentBaseDB): - """ - ORM for managing contextual experiments. - """ - - __tablename__ = "contextual_mabs" - - experiment_id: Mapped[int] = mapped_column( - ForeignKey("experiments_base.experiment_id", ondelete="CASCADE"), - primary_key=True, - nullable=False, - ) - - arms: Mapped[list["ContextualArmDB"]] = relationship( - "ContextualArmDB", back_populates="experiment", lazy="joined" - ) - - contexts: Mapped[list["ContextDB"]] = relationship( - "ContextDB", back_populates="experiment", lazy="joined" - ) - - draws: Mapped[list["ContextualDrawDB"]] = relationship( - "ContextualDrawDB", back_populates="experiment", lazy="joined" - ) - - __mapper_args__ = {"polymorphic_identity": "contextual_mabs"} - - def to_dict(self) -> dict: - """ - Convert the ORM object to a dictionary. 
- """ - return { - "experiment_id": self.experiment_id, - "user_id": self.user_id, - "workspace_id": self.workspace_id, - "name": self.name, - "description": self.description, - "sticky_assignment": self.sticky_assignment, - "auto_fail": self.auto_fail, - "auto_fail_value": self.auto_fail_value, - "auto_fail_unit": self.auto_fail_unit, - "created_datetime_utc": self.created_datetime_utc, - "is_active": self.is_active, - "n_trials": self.n_trials, - "arms": [arm.to_dict() for arm in self.arms], - "contexts": [context.to_dict() for context in self.contexts], - "prior_type": self.prior_type, - "reward_type": self.reward_type, - } - - -class ContextualArmDB(ArmBaseDB): - """ - ORM for managing contextual arms of an experiment - """ - - __tablename__ = "contextual_arms" - - arm_id: Mapped[int] = mapped_column( - ForeignKey("arms_base.arm_id", ondelete="CASCADE"), - primary_key=True, - nullable=False, - ) - - # prior variables for CMAB arms - mu_init: Mapped[float] = mapped_column(Float, nullable=False) - sigma_init: Mapped[float] = mapped_column(Float, nullable=False) - mu: Mapped[list[float]] = mapped_column(ARRAY(Float), nullable=False) - covariance: Mapped[list[float]] = mapped_column(ARRAY(Float), nullable=False) - - experiment: Mapped[ContextualBanditDB] = relationship( - "ContextualBanditDB", back_populates="arms", lazy="joined" - ) - draws: Mapped[list["ContextualDrawDB"]] = relationship( - "ContextualDrawDB", back_populates="arm", lazy="joined" - ) - - __mapper_args__ = {"polymorphic_identity": "contextual_arms"} - - def to_dict(self) -> dict: - """ - Convert the ORM object to a dictionary. 
- """ - return { - "arm_id": self.arm_id, - "name": self.name, - "description": self.description, - "mu_init": self.mu_init, - "sigma_init": self.sigma_init, - "mu": self.mu, - "covariance": self.covariance, - "draws": [draw.to_dict() for draw in self.draws], - } - - -class ContextDB(Base): - """ - ORM for managing context for an experiment - """ - - __tablename__ = "contexts" - - context_id: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False) - experiment_id: Mapped[int] = mapped_column( - Integer, ForeignKey("contextual_mabs.experiment_id"), nullable=False - ) - user_id: Mapped[int] = mapped_column( - Integer, ForeignKey("users.user_id"), nullable=False - ) - name: Mapped[str] = mapped_column(String(length=150), nullable=False) - description: Mapped[str] = mapped_column(String(length=500), nullable=True) - value_type: Mapped[str] = mapped_column(String(length=50), nullable=False) - - experiment: Mapped[ContextualBanditDB] = relationship( - "ContextualBanditDB", back_populates="contexts", lazy="joined" - ) - - def to_dict(self) -> dict: - """ - Convert the ORM object to a dictionary. - """ - return { - "context_id": self.context_id, - "name": self.name, - "description": self.description, - "value_type": self.value_type, - } - - -class ContextualDrawDB(DrawsBaseDB): - """ - ORM for managing draws of an experiment - """ - - __tablename__ = "contextual_draws" - - draw_id: Mapped[str] = mapped_column( - ForeignKey("draws_base.draw_id", ondelete="CASCADE"), - primary_key=True, - nullable=False, - ) - - context_val: Mapped[list] = mapped_column(ARRAY(Float), nullable=False) - arm: Mapped[ContextualArmDB] = relationship( - "ContextualArmDB", back_populates="draws", lazy="joined" - ) - experiment: Mapped[ContextualBanditDB] = relationship( - "ContextualBanditDB", back_populates="draws", lazy="joined" - ) - - __mapper_args__ = {"polymorphic_identity": "contextual_draws"} - - def to_dict(self) -> dict: - """ - Convert the ORM object to a dictionary. 
- """ - return { - "draw_id": self.draw_id, - "client_id": self.client_id, - "draw_datetime_utc": self.draw_datetime_utc, - "context_val": self.context_val, - "arm_id": self.arm_id, - "experiment_id": self.experiment_id, - "user_id": self.user_id, - "reward": self.reward, - "observation_type": self.observation_type, - "observed_datetime_utc": self.observed_datetime_utc, - } - - -async def save_contextual_mab_to_db( - experiment: ContextualBandit, - user_id: int, - workspace_id: int, - asession: AsyncSession, -) -> ContextualBanditDB: - """ - Save the experiment to the database. - """ - contexts = [ - ContextDB( - name=context.name, - description=context.description, - value_type=context.value_type.value, - user_id=user_id, - ) - for context in experiment.contexts - ] - arms = [] - for arm in experiment.arms: - arms.append( - ContextualArmDB( - name=arm.name, - description=arm.description, - mu_init=arm.mu_init, - sigma_init=arm.sigma_init, - mu=(np.ones(len(experiment.contexts)) * arm.mu_init).tolist(), - covariance=( - np.identity(len(experiment.contexts)) * arm.sigma_init - ).tolist(), - user_id=user_id, - n_outcomes=arm.n_outcomes, - ) - ) - - experiment_db = ContextualBanditDB( - name=experiment.name, - description=experiment.description, - user_id=user_id, - workspace_id=workspace_id, - is_active=experiment.is_active, - created_datetime_utc=datetime.now(timezone.utc), - n_trials=0, - arms=arms, - sticky_assignment=experiment.sticky_assignment, - auto_fail=experiment.auto_fail, - auto_fail_value=experiment.auto_fail_value, - auto_fail_unit=experiment.auto_fail_unit, - contexts=contexts, - prior_type=experiment.prior_type.value, - reward_type=experiment.reward_type.value, - ) - - asession.add(experiment_db) - await asession.commit() - await asession.refresh(experiment_db) - - return experiment_db - - -async def get_all_contextual_mabs( - workspace_id: int, - asession: AsyncSession, -) -> Sequence[ContextualBanditDB]: - """ - Get all the contextual experiments 
from the database for a specific workspace. - """ - statement = ( - select(ContextualBanditDB) - .where(ContextualBanditDB.workspace_id == workspace_id) - .order_by(ContextualBanditDB.experiment_id) - ) - - return (await asession.execute(statement)).unique().scalars().all() - - -async def get_contextual_mab_by_id( - experiment_id: int, workspace_id: int, asession: AsyncSession -) -> ContextualBanditDB | None: - """ - Get the contextual experiment by id from a specific workspace. - """ - condition = [ - ContextualBanditDB.experiment_id == experiment_id, - ContextualBanditDB.workspace_id == workspace_id, - ] - - statement = select(ContextualBanditDB).where(*condition) - result = await asession.execute(statement) - - return result.unique().scalar_one_or_none() - - -async def delete_contextual_mab_by_id( - experiment_id: int, workspace_id: int, asession: AsyncSession -) -> None: - """ - Delete the contextual experiment by id. - """ - await asession.execute( - delete(NotificationsDB).where(NotificationsDB.experiment_id == experiment_id) - ) - - await asession.execute( - delete(ContextualDrawDB).where(ContextualDrawDB.experiment_id == experiment_id) - ) - - await asession.execute( - delete(ContextDB).where(ContextDB.experiment_id == experiment_id) - ) - - await asession.execute( - delete(ContextualArmDB).where(ContextualArmDB.experiment_id == experiment_id) - ) - - await asession.execute( - delete(ContextualBanditDB).where( - and_( - ContextualBanditDB.workspace_id == workspace_id, - ContextualBanditDB.experiment_id == experiment_id, - ContextualBanditDB.experiment_id == ExperimentBaseDB.experiment_id, - ) - ) - ) - await asession.commit() - return None - - -async def save_contextual_obs_to_db( - draw: ContextualDrawDB, - reward: float, - asession: AsyncSession, - observation_type: ObservationType, -) -> ContextualDrawDB: - """ - Save the observation to the database. 
- """ - draw.reward = reward - draw.observed_datetime_utc = datetime.now(timezone.utc) - draw.observation_type = observation_type # Remove .value, pass enum directly - - await asession.commit() - await asession.refresh(draw) - - return draw - - -async def get_contextual_obs_by_experiment_arm_id( - experiment_id: int, - arm_id: int, - asession: AsyncSession, -) -> Sequence[ContextualDrawDB]: - """Get the observations for a specific arm of an experiment.""" - statement = ( - select(ContextualDrawDB) - .where( - and_( - ContextualDrawDB.experiment_id == experiment_id, - ContextualDrawDB.arm_id == arm_id, - ContextualDrawDB.reward.is_not(None), - ) - ) - .order_by(ContextualDrawDB.observed_datetime_utc) - ) - - result = await asession.execute(statement) - return result.unique().scalars().all() - - -async def get_all_contextual_obs_by_experiment_id( - experiment_id: int, - workspace_id: int, - asession: AsyncSession, -) -> Sequence[ContextualDrawDB]: - """ - Get all observations for an experiment, - verified to belong to the specified workspace. - """ - # First, verify experiment belongs to the workspace - experiment = await get_contextual_mab_by_id( - experiment_id=experiment_id, - workspace_id=workspace_id, - asession=asession, - ) - - if experiment is None: - # Return empty list if experiment doesn't exist or doesn't belong to workspace - return [] - - # Get all observations for this experiment - statement = ( - select(ContextualDrawDB) - .where( - and_( - ContextualDrawDB.experiment_id == experiment_id, - ContextualDrawDB.reward.is_not(None), - ) - ) - .order_by(ContextualDrawDB.observed_datetime_utc) - ) - - result = await asession.execute(statement) - return result.unique().scalars().all() - - -async def get_draw_by_id( - draw_id: str, asession: AsyncSession -) -> ContextualDrawDB | None: - """ - Get the draw by its ID, which should be unique across the system. 
- """ - statement = select(ContextualDrawDB).where(ContextualDrawDB.draw_id == draw_id) - result = await asession.execute(statement) - return result.unique().scalar_one_or_none() - - -async def get_draw_by_client_id( - client_id: str, experiment_id: int, asession: AsyncSession -) -> ContextualDrawDB | None: - """ - Get the draw by client id for a specific experiment. - """ - statement = ( - select(ContextualDrawDB) - .where(ContextualDrawDB.client_id == client_id) - .where(ContextualDrawDB.client_id.is_not(None)) - .where(ContextualDrawDB.experiment_id == experiment_id) - ) - result = await asession.execute(statement) - - return result.unique().scalars().first() - - -async def save_draw_to_db( - experiment_id: int, - arm_id: int, - context_val: list[float], - draw_id: str, - client_id: str | None, - user_id: int | None, - asession: AsyncSession, - workspace_id: int | None, -) -> ContextualDrawDB: - """ - Save the draw to the database. - """ - # If user_id is not provided but needed, get it from the experiment - if user_id is None: - if workspace_id is not None: - # Try to get experiment with workspace_id - experiment = await get_contextual_mab_by_id( - experiment_id=experiment_id, - workspace_id=workspace_id, - asession=asession, - ) - if experiment: - user_id = experiment.user_id - else: - raise ValueError(f"Experiment with id {experiment_id} not found") - else: - # Fall back to direct get if workspace_id not provided - experiment = await asession.get(ContextualBanditDB, experiment_id) - if experiment: - user_id = experiment.user_id - else: - raise ValueError(f"Experiment with id {experiment_id} not found") - - if user_id is None: - raise ValueError("User ID must be provided or derivable from experiment") - - draw_db = ContextualDrawDB( - draw_id=draw_id, - client_id=client_id, - arm_id=arm_id, - experiment_id=experiment_id, - user_id=user_id, - context_val=context_val, - draw_datetime_utc=datetime.now(timezone.utc), - ) - - asession.add(draw_db) - await 
asession.commit() - await asession.refresh(draw_db) - - return draw_db diff --git a/backend/app/contextual_mab/observation.py b/backend/app/contextual_mab/observation.py deleted file mode 100644 index e655bbf..0000000 --- a/backend/app/contextual_mab/observation.py +++ /dev/null @@ -1,126 +0,0 @@ -from datetime import datetime, timezone -from typing import Sequence - -import numpy as np -from fastapi.exceptions import HTTPException -from sqlalchemy.ext.asyncio import AsyncSession - -from ..schemas import ( - ObservationType, - RewardLikelihood, -) -from .models import ( - ContextualArmDB, - ContextualBanditDB, - ContextualDrawDB, - get_contextual_obs_by_experiment_arm_id, - save_contextual_obs_to_db, -) -from .sampling_utils import update_arm_params -from .schemas import ( - ContextualArmResponse, - ContextualBanditSample, -) - - -async def update_based_on_outcome( - experiment: ContextualBanditDB, - draw: ContextualDrawDB, - reward: float, - asession: AsyncSession, - observation_type: ObservationType, -) -> ContextualArmResponse: - """ - Update the arm based on the outcome of the draw. - - This is a helper function to allow `auto_fail` job to call - it as well. 
- """ - - update_experiment_metadata(experiment) - - arm = get_arm_from_experiment(experiment, draw.arm_id) - arm.n_outcomes += 1 - - # Ensure reward is binary for Bernoulli reward type - if experiment.reward_type == RewardLikelihood.BERNOULLI.value: - if reward not in [0, 1]: - raise HTTPException( - status_code=400, - detail="Reward must be 0 or 1 for Bernoulli reward type.", - ) - - # Get data for arm update - all_obs, contexts, rewards = await prepare_data_for_arm_update( - experiment.experiment_id, arm.arm_id, asession, draw, reward - ) - - experiment_data = ContextualBanditSample.model_validate(experiment) - mu, covariance = update_arm_params( - arm=ContextualArmResponse.model_validate(arm), - prior_type=experiment_data.prior_type, - reward_type=experiment_data.reward_type, - context=contexts, - reward=rewards, - ) - - await save_updated_data( - arm, mu, covariance, draw, reward, observation_type, asession - ) - - return ContextualArmResponse.model_validate(arm) - - -def update_experiment_metadata(experiment: ContextualBanditDB) -> None: - """Update experiment metadata with new trial information""" - experiment.n_trials += 1 - experiment.last_trial_datetime_utc = datetime.now(tz=timezone.utc) - - -def get_arm_from_experiment( - experiment: ContextualBanditDB, arm_id: int -) -> ContextualArmDB: - """Get and validate the arm from the experiment""" - arms = [a for a in experiment.arms if a.arm_id == arm_id] - if not arms: - raise HTTPException(status_code=404, detail=f"Arm with id {arm_id} not found") - return arms[0] - - -async def prepare_data_for_arm_update( - experiment_id: int, - arm_id: int, - asession: AsyncSession, - draw: ContextualDrawDB, - reward: float, -) -> tuple[Sequence[ContextualDrawDB], list[list], list[float]]: - """Prepare the data needed for updating arm parameters""" - all_obs = await get_contextual_obs_by_experiment_arm_id( - experiment_id=experiment_id, - arm_id=arm_id, - asession=asession, - ) - - rewards = [obs.reward for obs in 
all_obs] + [reward] - contexts = [obs.context_val for obs in all_obs] - contexts.append(draw.context_val) - - return all_obs, contexts, rewards - - -async def save_updated_data( - arm: ContextualArmDB, - mu: np.ndarray, - covariance: np.ndarray, - draw: ContextualDrawDB, - reward: float, - observation_type: ObservationType, - asession: AsyncSession, -) -> None: - """Save the updated arm and observation data""" - arm.mu = mu.tolist() - arm.covariance = covariance.tolist() - asession.add(arm) - await asession.commit() - - await save_contextual_obs_to_db(draw, reward, asession, observation_type) diff --git a/backend/app/contextual_mab/routers.py b/backend/app/contextual_mab/routers.py deleted file mode 100644 index 08eea28..0000000 --- a/backend/app/contextual_mab/routers.py +++ /dev/null @@ -1,395 +0,0 @@ -from typing import Annotated, List, Optional -from uuid import uuid4 - -from fastapi import APIRouter, Depends -from fastapi.exceptions import HTTPException -from sqlalchemy.ext.asyncio import AsyncSession - -from ..auth.dependencies import ( - authenticate_workspace_key, - get_verified_user, - require_admin_role, -) -from ..database import get_async_session -from ..models import get_notifications_from_db, save_notifications_to_db -from ..schemas import ( - ContextType, - NotificationsResponse, - ObservationType, - Outcome, -) -from ..users.models import UserDB -from ..utils import setup_logger -from ..workspaces.models import ( - WorkspaceDB, - get_user_default_workspace, -) -from .models import ( - ContextualBanditDB, - ContextualDrawDB, - delete_contextual_mab_by_id, - get_all_contextual_mabs, - get_all_contextual_obs_by_experiment_id, - get_contextual_mab_by_id, - get_draw_by_client_id, - get_draw_by_id, - save_contextual_mab_to_db, - save_draw_to_db, -) -from .observation import update_based_on_outcome -from .sampling_utils import choose_arm -from .schemas import ( - CMABDrawResponse, - CMABObservationResponse, - ContextInput, - ContextualArmResponse, - 
ContextualBandit, - ContextualBanditResponse, - ContextualBanditSample, -) - -router = APIRouter(prefix="/contextual_mab", tags=["Contextual Bandits"]) - -logger = setup_logger(__name__) - - -@router.post("/", response_model=ContextualBanditResponse) -async def create_contextual_mabs( - experiment: ContextualBandit, - user_db: Annotated[UserDB, Depends(require_admin_role)], - asession: AsyncSession = Depends(get_async_session), -) -> ContextualBanditResponse | HTTPException: - """ - Create a new contextual experiment with different priors for each context. - """ - workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) - - if workspace_db is None: - raise HTTPException( - status_code=404, - detail="Workspace not found. Please create a workspace first.", - ) - - cmab = await save_contextual_mab_to_db( - experiment, user_db.user_id, workspace_db.workspace_id, asession - ) - notifications = await save_notifications_to_db( - experiment_id=cmab.experiment_id, - user_id=user_db.user_id, - notifications=experiment.notifications, - asession=asession, - ) - cmab_dict = cmab.to_dict() - cmab_dict["notifications"] = [n.to_dict() for n in notifications] - return ContextualBanditResponse.model_validate(cmab_dict) - - -@router.get("/", response_model=list[ContextualBanditResponse]) -async def get_contextual_mabs( - user_db: Annotated[UserDB, Depends(get_verified_user)], - asession: AsyncSession = Depends(get_async_session), -) -> list[ContextualBanditResponse]: - """ - Get details of all experiments. - """ - workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) - - if workspace_db is None: - raise HTTPException( - status_code=404, - detail="Workspace not found. 
Please create a workspace first.", - ) - - experiments = await get_all_contextual_mabs(workspace_db.workspace_id, asession) - all_experiments = [] - for exp in experiments: - exp_dict = exp.to_dict() - exp_dict["notifications"] = [ - n.to_dict() - for n in await get_notifications_from_db( - exp.experiment_id, exp.user_id, asession - ) - ] - all_experiments.append( - ContextualBanditResponse.model_validate( - { - **exp_dict, - "notifications": [ - NotificationsResponse.model_validate(n) - for n in exp_dict["notifications"] - ], - } - ) - ) - - return all_experiments - - -@router.get("/{experiment_id}", response_model=ContextualBanditResponse) -async def get_contextual_mab( - experiment_id: int, - user_db: Annotated[UserDB, Depends(get_verified_user)], - asession: AsyncSession = Depends(get_async_session), -) -> ContextualBanditResponse | HTTPException: - """ - Get details of experiment with the provided `experiment_id`. - """ - workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) - - if workspace_db is None: - raise HTTPException( - status_code=404, - detail="Workspace not found. Please create a workspace first.", - ) - - experiment = await get_contextual_mab_by_id( - experiment_id, workspace_db.workspace_id, asession - ) - if experiment is None: - raise HTTPException( - status_code=404, detail=f"Experiment with id {experiment_id} not found" - ) - - experiment_dict = experiment.to_dict() - experiment_dict["notifications"] = [ - n.to_dict() - for n in await get_notifications_from_db( - experiment.experiment_id, experiment.user_id, asession - ) - ] - return ContextualBanditResponse.model_validate(experiment_dict) - - -@router.delete("/{experiment_id}", response_model=dict) -async def delete_contextual_mab( - experiment_id: int, - user_db: Annotated[UserDB, Depends(require_admin_role)], - asession: AsyncSession = Depends(get_async_session), -) -> dict: - """ - Delete the experiment with the provided `experiment_id`. 
- """ - try: - workspace_db = await get_user_default_workspace( - asession=asession, user_db=user_db - ) - - if workspace_db is None: - raise HTTPException( - status_code=404, - detail="Workspace not found. Please create a workspace first.", - ) - - experiment = await get_contextual_mab_by_id( - experiment_id, workspace_db.workspace_id, asession - ) - if experiment is None: - raise HTTPException( - status_code=404, detail=f"Experiment with id {experiment_id} not found" - ) - await delete_contextual_mab_by_id( - experiment_id, workspace_db.workspace_id, asession - ) - return {"detail": f"Experiment {experiment_id} deleted successfully."} - except Exception as e: - raise HTTPException(status_code=500, detail=f"Error: {e}") from e - - -@router.post("/{experiment_id}/draw", response_model=CMABDrawResponse) -async def draw_arm( - experiment_id: int, - context: List[ContextInput], - draw_id: Optional[str] = None, - client_id: Optional[str] = None, - workspace_db: WorkspaceDB = Depends(authenticate_workspace_key), - asession: AsyncSession = Depends(get_async_session), -) -> CMABDrawResponse: - """ - Get which arm to pull next for provided experiment. 
- """ - # Get workspace from user context - workspace_id = workspace_db.workspace_id - - experiment = await get_contextual_mab_by_id(experiment_id, workspace_id, asession) - - if experiment is None: - raise HTTPException( - status_code=404, detail=f"Experiment with id {experiment_id} not found" - ) - - # Check context inputs - if len(experiment.contexts) != len(context): - raise HTTPException( - status_code=400, - detail="Number of contexts provided does not match the num contexts.", - ) - experiment_data = ContextualBanditSample.model_validate(experiment) - sorted_context = list(sorted(context, key=lambda x: x.context_id)) - - try: - for c_input, c_exp in zip( - sorted_context, - sorted(experiment.contexts, key=lambda x: x.context_id), - ): - if c_exp.value_type == ContextType.BINARY.value: - Outcome(c_input.context_value) - except ValueError as e: - raise HTTPException( - status_code=400, - detail=f"Invalid context value: {e}", - ) from e - - # Generate UUID if not provided - if draw_id is None: - draw_id = str(uuid4()) - - existing_draw = await get_draw_by_id(draw_id, asession) - if existing_draw: - raise HTTPException( - status_code=400, - detail=f"Draw ID {draw_id} already exists.", - ) - - # Check if sticky assignment - if experiment.sticky_assignment and not client_id: - raise HTTPException( - status_code=400, - detail="Client ID is required for sticky assignment.", - ) - - chosen_arm = choose_arm( - experiment_data, - [c.context_value for c in sorted_context], - ) - chosen_arm_id = experiment.arms[chosen_arm].arm_id - if experiment.sticky_assignment and client_id: - previous_draw = await get_draw_by_client_id( - client_id=client_id, - experiment_id=experiment.experiment_id, - asession=asession, - ) - if previous_draw: - chosen_arm_id = previous_draw.arm_id - - try: - _ = await save_draw_to_db( - experiment_id=experiment.experiment_id, - arm_id=chosen_arm_id, - context_val=[c.context_value for c in sorted_context], - draw_id=draw_id, - client_id=client_id, - 
user_id=None, - asession=asession, - workspace_id=workspace_id, - ) - except Exception as e: - raise HTTPException( - status_code=500, - detail=f"Error saving draw to database: {e}", - ) from e - - return CMABDrawResponse.model_validate( - { - "draw_id": draw_id, - "client_id": client_id, - "arm": ContextualArmResponse.model_validate( - [arm for arm in experiment.arms if arm.arm_id == chosen_arm_id][0] - ), - } - ) - - -@router.put("/{experiment_id}/{draw_id}/{reward}", response_model=ContextualArmResponse) -async def update_arm( - experiment_id: int, - draw_id: str, - reward: float, - workspace_db: WorkspaceDB = Depends(authenticate_workspace_key), - asession: AsyncSession = Depends(get_async_session), -) -> ContextualArmResponse: - """ - Update the arm with the provided `arm_id` for the given - `experiment_id` based on the reward. - """ - # Get workspace from user context - workspace_id = workspace_db.workspace_id - - # Get the experiment and do checks - experiment, draw = await validate_experiment_and_draw( - experiment_id, draw_id, workspace_id, asession - ) - - return await update_based_on_outcome( - experiment, draw, reward, asession, ObservationType.USER - ) - - -@router.get( - "/{experiment_id}/outcomes", - response_model=list[CMABObservationResponse], -) -async def get_outcomes( - experiment_id: int, - workspace_db: WorkspaceDB = Depends(authenticate_workspace_key), - asession: AsyncSession = Depends(get_async_session), -) -> list[CMABObservationResponse]: - """ - Get the outcomes for the experiment. 
- """ - # Get workspace from user context - workspace_id = workspace_db.workspace_id - - experiment = await get_contextual_mab_by_id(experiment_id, workspace_id, asession) - if not experiment: - raise HTTPException( - status_code=404, detail=f"Experiment with id {experiment_id} not found" - ) - - observations = await get_all_contextual_obs_by_experiment_id( - experiment_id=experiment.experiment_id, - workspace_id=workspace_id, - asession=asession, - ) - return [CMABObservationResponse.model_validate(obs) for obs in observations] - - -async def validate_experiment_and_draw( - experiment_id: int, - draw_id: str, - workspace_id: int, - asession: AsyncSession, -) -> tuple[ContextualBanditDB, ContextualDrawDB]: - """ - Validate that the experiment exists in the workspace - and the draw exists for that experiment. - """ - experiment = await get_contextual_mab_by_id( - experiment_id=experiment_id, - workspace_id=workspace_id, - asession=asession, - ) - if experiment is None: - raise HTTPException( - status_code=404, detail=f"Experiment with id {experiment_id} not found" - ) - - draw = await get_draw_by_id(draw_id=draw_id, asession=asession) - if draw is None: - raise HTTPException(status_code=404, detail=f"Draw with id {draw_id} not found") - - if draw.experiment_id != experiment_id: - raise HTTPException( - status_code=400, - detail=( - f"Draw with id {draw_id} does not belong " - f"to experiment with id {experiment_id}", - ), - ) - - if draw.reward is not None: - raise HTTPException( - status_code=400, - detail=f"Draw with id {draw_id} already has a reward.", - ) - - return experiment, draw diff --git a/backend/app/contextual_mab/schemas.py b/backend/app/contextual_mab/schemas.py deleted file mode 100644 index 57baaf4..0000000 --- a/backend/app/contextual_mab/schemas.py +++ /dev/null @@ -1,268 +0,0 @@ -from datetime import datetime -from typing import Optional, Self - -from pydantic import BaseModel, ConfigDict, Field, model_validator - -from ..schemas import ( - 
ArmPriors, - AutoFailUnitType, - ContextType, - Notifications, - NotificationsResponse, - RewardLikelihood, - allowed_combos_cmab, -) - - -class Context(BaseModel): - """ - Pydantic model for a binary-valued context of the experiment. - """ - - name: str = Field( - description="Name of the context", - examples=["Context 1"], - ) - description: str = Field( - description="Description of the context", - examples=["This is a description of the context."], - ) - value_type: ContextType = Field( - description="Type of value the context can take", default=ContextType.BINARY - ) - model_config = ConfigDict(from_attributes=True) - - -class ContextResponse(Context): - """ - Pydantic model for an response for context creation - """ - - context_id: int - model_config = ConfigDict(from_attributes=True) - - -class ContextInput(BaseModel): - """ - Pydantic model for a context input - """ - - context_id: int - context_value: float - model_config = ConfigDict(from_attributes=True) - - -class ContextualArm(BaseModel): - """ - Pydantic model for a contextual arm of the experiment. - """ - - name: str = Field( - max_length=150, - examples=["Arm 1"], - ) - description: str = Field( - max_length=500, - examples=["This is a description of the arm."], - ) - - mu_init: float = Field( - default=0.0, - examples=[0.0, 1.2, 5.7], - description="Mean parameter for Normal prior", - ) - - sigma_init: float = Field( - default=1.0, - examples=[1.0, 0.5, 2.0], - description="Standard deviation parameter for Normal prior", - ) - n_outcomes: Optional[int] = Field( - default=0, - description="Number of outcomes for the arm", - examples=[0, 10, 15], - ) - - @model_validator(mode="after") - def check_values(self) -> Self: - """ - Check if the values are unique and set new attributes. 
- """ - sigma = self.sigma_init - if sigma is not None and sigma <= 0: - raise ValueError("Std dev must be greater than 0.") - return self - - model_config = ConfigDict(from_attributes=True) - - -class ContextualArmResponse(ContextualArm): - """ - Pydantic model for an response for contextual arm creation - """ - - arm_id: int - mu: list[float] - covariance: list[list[float]] - - model_config = ConfigDict(from_attributes=True) - - -class ContextualBanditBase(BaseModel): - """ - Pydantic model for a contextual experiment - Base model. - Note: Do not use this model directly. Use ContextualBandit instead. - """ - - name: str = Field( - max_length=150, - examples=["Experiment 1"], - ) - - description: str = Field( - max_length=500, - examples=["This is a description of the experiment."], - ) - - sticky_assignment: bool = Field( - description="Whether the arm assignment is sticky or not.", - default=False, - ) - - auto_fail: bool = Field( - description=( - "Whether the experiment should fail automatically after " - "a certain period if no outcome is registered." - ), - default=False, - ) - - auto_fail_value: Optional[int] = Field( - description="The time period after which the experiment should fail.", - default=None, - ) - - auto_fail_unit: Optional[AutoFailUnitType] = Field( - description="The time unit for the auto fail period.", - default=None, - ) - - reward_type: RewardLikelihood = Field( - description="The type of reward we observe from the experiment.", - default=RewardLikelihood.BERNOULLI, - ) - - prior_type: ArmPriors = Field( - description="The type of prior distribution for the arms.", - default=ArmPriors.NORMAL, - ) - - is_active: bool = True - - model_config = ConfigDict(from_attributes=True) - - -class ContextualBandit(ContextualBanditBase): - """ - Pydantic model for a contextual experiment. 
- """ - - arms: list[ContextualArm] - contexts: list[Context] - notifications: Notifications - - @model_validator(mode="after") - def auto_fail_unit_and_value_set(self) -> Self: - """ - Validate that the auto fail unit and value are set if auto fail is True. - """ - if self.auto_fail: - if ( - not self.auto_fail_value - or not self.auto_fail_unit - or self.auto_fail_value <= 0 - ): - raise ValueError( - ( - "Auto fail is enabled. " - "Please provide both auto_fail_value and auto_fail_unit." - ) - ) - return self - - @model_validator(mode="after") - def arms_at_least_two(self) -> Self: - """ - Validate that the experiment has at least two arms. - """ - if len(self.arms) < 2: - raise ValueError("The experiment must have at least two arms.") - return self - - @model_validator(mode="after") - def check_prior_reward_type_combo(self) -> Self: - """ - Validate that the prior and reward type combination is allowed. - """ - - if (self.prior_type, self.reward_type) not in allowed_combos_cmab: - raise ValueError("Prior and reward type combo not supported.") - return self - - model_config = ConfigDict(from_attributes=True) - - -class ContextualBanditResponse(ContextualBanditBase): - """ - Pydantic model for an response for contextual experiment creation. - Returns the id of the experiment, the arms and the contexts - """ - - experiment_id: int - workspace_id: int - arms: list[ContextualArmResponse] - contexts: list[ContextResponse] - notifications: list[NotificationsResponse] - created_datetime_utc: datetime - last_trial_datetime_utc: Optional[datetime] = None - n_trials: int - - model_config = ConfigDict(from_attributes=True) - - -class ContextualBanditSample(ContextualBanditBase): - """ - Pydantic model for a contextual experiment sample. 
- """ - - experiment_id: int - arms: list[ContextualArmResponse] - contexts: list[ContextResponse] - - -class CMABObservationResponse(BaseModel): - """ - Pydantic model for an response for contextual observation creation - """ - - arm_id: int - reward: float - context_val: list[float] - - draw_id: str - client_id: str | None - observed_datetime_utc: datetime - - model_config = ConfigDict(from_attributes=True) - - -class CMABDrawResponse(BaseModel): - """ - Pydantic model for an response for contextual arm draw - """ - - draw_id: str - client_id: str | None - arm: ContextualArmResponse - - model_config = ConfigDict(from_attributes=True) diff --git a/backend/app/mab/__init__.py b/backend/app/mab/__init__.py deleted file mode 100644 index fa07d07..0000000 --- a/backend/app/mab/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .routers import router # noqa: F401 diff --git a/backend/app/mab/models.py b/backend/app/mab/models.py deleted file mode 100644 index f14d483..0000000 --- a/backend/app/mab/models.py +++ /dev/null @@ -1,419 +0,0 @@ -from datetime import datetime, timezone -from typing import Sequence - -from sqlalchemy import ( - Float, - ForeignKey, - and_, - delete, - select, -) -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import Mapped, mapped_column, relationship - -from ..models import ( - ArmBaseDB, - DrawsBaseDB, - ExperimentBaseDB, - NotificationsDB, -) -from ..schemas import ObservationType -from .schemas import MultiArmedBandit - - -class MultiArmedBanditDB(ExperimentBaseDB): - """ - ORM for managing experiments. 
- """ - - __tablename__ = "mabs" - - experiment_id: Mapped[int] = mapped_column( - ForeignKey("experiments_base.experiment_id", ondelete="CASCADE"), - primary_key=True, - nullable=False, - ) - arms: Mapped[list["MABArmDB"]] = relationship( - "MABArmDB", back_populates="experiment", lazy="joined" - ) - - draws: Mapped[list["MABDrawDB"]] = relationship( - "MABDrawDB", back_populates="experiment", lazy="joined" - ) - - __mapper_args__ = {"polymorphic_identity": "mabs"} - - def to_dict(self) -> dict: - """ - Convert the ORM object to a dictionary. - """ - return { - "experiment_id": self.experiment_id, - "user_id": self.user_id, - "workspace_id": self.workspace_id, - "name": self.name, - "description": self.description, - "sticky_assignment": self.sticky_assignment, - "auto_fail": self.auto_fail, - "auto_fail_value": self.auto_fail_value, - "auto_fail_unit": self.auto_fail_unit, - "created_datetime_utc": self.created_datetime_utc, - "is_active": self.is_active, - "n_trials": self.n_trials, - "arms": [arm.to_dict() for arm in self.arms], - "prior_type": self.prior_type, - "reward_type": self.reward_type, - } - - -class MABArmDB(ArmBaseDB): - """ - ORM for managing arms of an experiment - """ - - __tablename__ = "mab_arms" - - arm_id: Mapped[int] = mapped_column( - ForeignKey("arms_base.arm_id", ondelete="CASCADE"), - primary_key=True, - nullable=False, - ) - - # prior variables for MAB arms - alpha: Mapped[float] = mapped_column(Float, nullable=True) - beta: Mapped[float] = mapped_column(Float, nullable=True) - mu: Mapped[float] = mapped_column(Float, nullable=True) - sigma: Mapped[float] = mapped_column(Float, nullable=True) - alpha_init: Mapped[float] = mapped_column(Float, nullable=True) - beta_init: Mapped[float] = mapped_column(Float, nullable=True) - mu_init: Mapped[float] = mapped_column(Float, nullable=True) - sigma_init: Mapped[float] = mapped_column(Float, nullable=True) - experiment: Mapped[MultiArmedBanditDB] = relationship( - "MultiArmedBanditDB", 
back_populates="arms", lazy="joined" - ) - - draws: Mapped[list["MABDrawDB"]] = relationship( - "MABDrawDB", back_populates="arm", lazy="joined" - ) - - __mapper_args__ = {"polymorphic_identity": "mab_arms"} - - def to_dict(self) -> dict: - """ - Convert the ORM object to a dictionary. - """ - return { - "arm_id": self.arm_id, - "name": self.name, - "description": self.description, - "alpha": self.alpha, - "beta": self.beta, - "mu": self.mu, - "sigma": self.sigma, - "alpha_init": self.alpha_init, - "beta_init": self.beta_init, - "mu_init": self.mu_init, - "sigma_init": self.sigma_init, - "draws": [draw.to_dict() for draw in self.draws], - } - - -class MABDrawDB(DrawsBaseDB): - """ - ORM for managing draws of an experiment - """ - - __tablename__ = "mab_draws" - - draw_id: Mapped[str] = mapped_column( - ForeignKey("draws_base.draw_id", ondelete="CASCADE"), - primary_key=True, - nullable=False, - ) - - arm: Mapped[MABArmDB] = relationship( - "MABArmDB", back_populates="draws", lazy="joined" - ) - experiment: Mapped[MultiArmedBanditDB] = relationship( - "MultiArmedBanditDB", back_populates="draws", lazy="joined" - ) - - __mapper_args__ = {"polymorphic_identity": "mab_draws"} - - def to_dict(self) -> dict: - """ - Convert the ORM object to a dictionary. - """ - return { - "draw_id": self.draw_id, - "client_id": self.client_id, - "draw_datetime_utc": self.draw_datetime_utc, - "arm_id": self.arm_id, - "experiment_id": self.experiment_id, - "user_id": self.user_id, - "reward": self.reward, - "observation_type": self.observation_type, - "observed_datetime_utc": self.observed_datetime_utc, - } - - -async def save_mab_to_db( - experiment: MultiArmedBandit, - user_id: int, - workspace_id: int, - asession: AsyncSession, -) -> MultiArmedBanditDB: - """ - Save the experiment to the database. 
- """ - arms = [ - MABArmDB( - name=arm.name, - description=arm.description, - alpha_init=arm.alpha_init, - beta_init=arm.beta_init, - mu_init=arm.mu_init, - sigma_init=arm.sigma_init, - n_outcomes=arm.n_outcomes, - alpha=arm.alpha_init, - beta=arm.beta_init, - mu=arm.mu_init, - sigma=arm.sigma_init, - user_id=user_id, - ) - for arm in experiment.arms - ] - experiment_db = MultiArmedBanditDB( - name=experiment.name, - description=experiment.description, - user_id=user_id, - workspace_id=workspace_id, - is_active=experiment.is_active, - created_datetime_utc=datetime.now(timezone.utc), - n_trials=0, - arms=arms, - sticky_assignment=experiment.sticky_assignment, - auto_fail=experiment.auto_fail, - auto_fail_value=experiment.auto_fail_value, - auto_fail_unit=experiment.auto_fail_unit, - prior_type=experiment.prior_type.value, - reward_type=experiment.reward_type.value, - ) - - asession.add(experiment_db) - await asession.commit() - await asession.refresh(experiment_db) - - return experiment_db - - -async def get_all_mabs( - workspace_id: int, - asession: AsyncSession, -) -> Sequence[MultiArmedBanditDB]: - """ - Get all the experiments from the database for a specific workspace. - """ - statement = ( - select(MultiArmedBanditDB) - .where( - MultiArmedBanditDB.workspace_id == workspace_id, - ) - .order_by(MultiArmedBanditDB.experiment_id) - ) - - return (await asession.execute(statement)).unique().scalars().all() - - -async def get_mab_by_id( - experiment_id: int, - workspace_id: int, - asession: AsyncSession, -) -> MultiArmedBanditDB | None: - """ - Get the experiment by id from a specific workspace. 
- """ - conditions = [ - MultiArmedBanditDB.workspace_id == workspace_id, - MultiArmedBanditDB.experiment_id == experiment_id, - ] - - result = await asession.execute(select(MultiArmedBanditDB).where(and_(*conditions))) - - return result.unique().scalar_one_or_none() - - -async def delete_mab_by_id( - experiment_id: int, workspace_id: int, asession: AsyncSession -) -> None: - """ - Delete the experiment by id. - """ - await asession.execute( - delete(NotificationsDB).where(NotificationsDB.experiment_id == experiment_id) - ) - - await asession.execute( - delete(DrawsBaseDB).where(DrawsBaseDB.experiment_id == experiment_id) - ) - - await asession.execute( - delete(MABArmDB).where( - and_( - MABArmDB.arm_id == ArmBaseDB.arm_id, - ArmBaseDB.experiment_id == experiment_id, - ) - ) - ) - await asession.execute( - delete(MultiArmedBanditDB).where( - and_( - MultiArmedBanditDB.experiment_id == experiment_id, - MultiArmedBanditDB.experiment_id == ExperimentBaseDB.experiment_id, - MultiArmedBanditDB.workspace_id == workspace_id, - ) - ) - ) - await asession.commit() - return None - - -async def get_obs_by_experiment_arm_id( - experiment_id: int, arm_id: int, asession: AsyncSession -) -> Sequence[MABDrawDB]: - """ - Get the observations for the experiment and arm. - """ - statement = ( - select(MABDrawDB) - .where(MABDrawDB.experiment_id == experiment_id) - .where(MABDrawDB.reward.is_not(None)) - .where(MABDrawDB.arm_id == arm_id) - .order_by(MABDrawDB.observed_datetime_utc) - ) - - return (await asession.execute(statement)).unique().scalars().all() - - -async def get_all_obs_by_experiment_id( - experiment_id: int, - workspace_id: int, - asession: AsyncSession, -) -> Sequence[MABDrawDB]: - """ - Get the observations for the experiment. 
- """ - # First, verify experiment belongs to the workspace - experiment = await get_mab_by_id( - experiment_id=experiment_id, - workspace_id=workspace_id, - asession=asession, - ) - - if experiment is None: - # Return empty list if experiment doesn't exist or doesn't belong to workspace - return [] - - statement = ( - select(MABDrawDB) - .where(MABDrawDB.experiment_id == experiment_id) - .where(MABDrawDB.reward.is_not(None)) - .order_by(MABDrawDB.observed_datetime_utc) - ) - - return (await asession.execute(statement)).unique().scalars().all() - - -async def get_draw_by_id(draw_id: str, asession: AsyncSession) -> MABDrawDB | None: - """ - Get a draw by its ID, which should be unique across the system. - """ - statement = select(MABDrawDB).where(MABDrawDB.draw_id == draw_id) - result = await asession.execute(statement) - - return result.unique().scalar_one_or_none() - - -async def get_draw_by_client_id( - client_id: str, - experiment_id: int, - asession: AsyncSession, -) -> MABDrawDB | None: - """ - Get a draw by its client ID for a specific experiment. 
- """ - statement = ( - select(MABDrawDB) - .where(MABDrawDB.client_id == client_id) - .where(MABDrawDB.client_id.is_not(None)) - .where(MABDrawDB.experiment_id == experiment_id) - ) - result = await asession.execute(statement) - - return result.unique().scalars().first() - - -async def save_draw_to_db( - experiment_id: int, - arm_id: int, - draw_id: str, - client_id: str | None, - user_id: int | None, - asession: AsyncSession, - workspace_id: int | None = None, -) -> MABDrawDB: - """ - Save a draw to the database - """ - # If user_id is not provided but needed, get it from the experiment - if user_id is None and workspace_id is not None: - experiment = await get_mab_by_id( - experiment_id=experiment_id, - workspace_id=workspace_id, - asession=asession, - ) - - if experiment: - user_id = experiment.user_id - else: - raise ValueError(f"Experiment with id {experiment_id} not found") - - if user_id is None: - raise ValueError("User ID must be provided or derivable from experiment") - - draw_datetime_utc: datetime = datetime.now(timezone.utc) - - draw = MABDrawDB( - draw_id=draw_id, - client_id=client_id, - experiment_id=experiment_id, - user_id=user_id, - arm_id=arm_id, - draw_datetime_utc=draw_datetime_utc, - ) - - asession.add(draw) - await asession.commit() - await asession.refresh(draw) - - return draw - - -async def save_observation_to_db( - draw: MABDrawDB, - reward: float, - asession: AsyncSession, - observation_type: ObservationType, -) -> MABDrawDB: - """ - Save an observation to the database - """ - - draw.reward = reward - draw.observed_datetime_utc = datetime.now(timezone.utc) - draw.observation_type = observation_type - asession.add(draw) - await asession.commit() - await asession.refresh(draw) - - return draw diff --git a/backend/app/mab/observation.py b/backend/app/mab/observation.py deleted file mode 100644 index 0ef34c2..0000000 --- a/backend/app/mab/observation.py +++ /dev/null @@ -1,94 +0,0 @@ -from datetime import datetime, timezone - -from fastapi 
import HTTPException -from sqlalchemy.ext.asyncio import AsyncSession - -from ..schemas import ObservationType, Outcome, RewardLikelihood -from .models import ( - MABArmDB, - MABDrawDB, - MultiArmedBanditDB, - save_observation_to_db, -) -from .sampling_utils import update_arm_params -from .schemas import ( - ArmResponse, - MultiArmedBanditSample, -) - - -async def update_based_on_outcome( - experiment: MultiArmedBanditDB, - draw: MABDrawDB, - outcome: float, - asession: AsyncSession, - observation_type: ObservationType, -) -> ArmResponse: - """ - Update the arm parameters based on the outcome. - - This is a helper function to allow `auto_fail` job to call - it as well. - """ - update_experiment_metadata(experiment) - - arm = get_arm_from_experiment(experiment, draw.arm_id) - arm.n_outcomes += 1 - - experiment_data = MultiArmedBanditSample.model_validate(experiment) - await update_arm_parameters(arm, experiment_data, outcome) - await save_updated_data(arm, draw, outcome, observation_type, asession) - - return ArmResponse.model_validate(arm) - - -def update_experiment_metadata(experiment: MultiArmedBanditDB) -> None: - """Update experiment metadata with new trial information""" - experiment.n_trials += 1 - experiment.last_trial_datetime_utc = datetime.now(tz=timezone.utc) - - -def get_arm_from_experiment(experiment: MultiArmedBanditDB, arm_id: int) -> MABArmDB: - """Get and validate the arm from the experiment""" - arms = [a for a in experiment.arms if a.arm_id == arm_id] - if not arms: - raise HTTPException(status_code=404, detail=f"Arm with id {arm_id} not found") - return arms[0] - - -async def update_arm_parameters( - arm: MABArmDB, experiment_data: MultiArmedBanditSample, outcome: float -) -> None: - """Update the arm parameters based on the reward type and outcome""" - if experiment_data.reward_type == RewardLikelihood.BERNOULLI: - Outcome(outcome) # Check if reward is 0 or 1 - arm.alpha, arm.beta = update_arm_params( - ArmResponse.model_validate(arm), - 
experiment_data.prior_type, - experiment_data.reward_type, - outcome, - ) - elif experiment_data.reward_type == RewardLikelihood.NORMAL: - arm.mu, arm.sigma = update_arm_params( - ArmResponse.model_validate(arm), - experiment_data.prior_type, - experiment_data.reward_type, - outcome, - ) - else: - raise HTTPException( - status_code=400, - detail="Reward type not supported.", - ) - - -async def save_updated_data( - arm: MABArmDB, - draw: MABDrawDB, - outcome: float, - observation_type: ObservationType, - asession: AsyncSession, -) -> None: - """Save the updated arm and observation data""" - await asession.commit() - await save_observation_to_db(draw, outcome, asession, observation_type) diff --git a/backend/app/mab/routers.py b/backend/app/mab/routers.py deleted file mode 100644 index 9a22582..0000000 --- a/backend/app/mab/routers.py +++ /dev/null @@ -1,357 +0,0 @@ -from typing import Annotated, Optional -from uuid import uuid4 - -from fastapi import APIRouter, Depends -from fastapi.exceptions import HTTPException -from sqlalchemy.ext.asyncio import AsyncSession - -from ..auth.dependencies import ( - authenticate_workspace_key, - get_verified_user, - require_admin_role, -) -from ..database import get_async_session -from ..models import get_notifications_from_db, save_notifications_to_db -from ..schemas import NotificationsResponse, ObservationType -from ..users.models import UserDB -from ..utils import setup_logger -from ..workspaces.models import ( - WorkspaceDB, - get_user_default_workspace, -) -from .models import ( - MABDrawDB, - MultiArmedBanditDB, - delete_mab_by_id, - get_all_mabs, - get_all_obs_by_experiment_id, - get_draw_by_client_id, - get_draw_by_id, - get_mab_by_id, - save_draw_to_db, - save_mab_to_db, -) -from .observation import update_based_on_outcome -from .sampling_utils import choose_arm -from .schemas import ( - ArmResponse, - MABDrawResponse, - MABObservationResponse, - MultiArmedBandit, - MultiArmedBanditResponse, - MultiArmedBanditSample, -) - 
-router = APIRouter(prefix="/mab", tags=["Multi-Armed Bandits"]) - -logger = setup_logger(__name__) - - -@router.post("/", response_model=MultiArmedBanditResponse) -async def create_mab( - experiment: MultiArmedBandit, - user_db: Annotated[UserDB, Depends(require_admin_role)], - asession: AsyncSession = Depends(get_async_session), -) -> MultiArmedBanditResponse: - """ - Create a new experiment in the user's current workspace. - """ - workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) - - if workspace_db is None: - raise HTTPException( - status_code=404, - detail="Workspace not found. Please create a workspace first.", - ) - - mab = await save_mab_to_db( - experiment, user_db.user_id, workspace_db.workspace_id, asession - ) - - notifications = await save_notifications_to_db( - experiment_id=mab.experiment_id, - user_id=user_db.user_id, - notifications=experiment.notifications, - asession=asession, - ) - - mab_dict = mab.to_dict() - mab_dict["notifications"] = [n.to_dict() for n in notifications] - - return MultiArmedBanditResponse.model_validate(mab_dict) - - -@router.get("/", response_model=list[MultiArmedBanditResponse]) -async def get_mabs( - user_db: Annotated[UserDB, Depends(get_verified_user)], - asession: AsyncSession = Depends(get_async_session), -) -> list[MultiArmedBanditResponse]: - """ - Get details of all experiments in the user's current workspace. - """ - workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) - - if workspace_db is None: - raise HTTPException( - status_code=404, - detail="Workspace not found. 
Please create a workspace first.", - ) - - experiments = await get_all_mabs(workspace_db.workspace_id, asession) - - all_experiments = [] - for exp in experiments: - exp_dict = exp.to_dict() - exp_dict["notifications"] = [ - n.to_dict() - for n in await get_notifications_from_db( - exp.experiment_id, exp.user_id, asession - ) - ] - all_experiments.append( - MultiArmedBanditResponse.model_validate( - { - **exp_dict, - "notifications": [ - NotificationsResponse(**n) for n in exp_dict["notifications"] - ], - } - ) - ) - return all_experiments - - -@router.get("/{experiment_id}/", response_model=MultiArmedBanditResponse) -async def get_mab( - experiment_id: int, - user_db: Annotated[UserDB, Depends(get_verified_user)], - asession: AsyncSession = Depends(get_async_session), -) -> MultiArmedBanditResponse: - """ - Get details of experiment with the provided `experiment_id`. - """ - workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) - - if workspace_db is None: - raise HTTPException( - status_code=404, - detail="Workspace not found. Please create a workspace first.", - ) - - experiment = await get_mab_by_id(experiment_id, workspace_db.workspace_id, asession) - - if experiment is None: - raise HTTPException( - status_code=404, detail=f"Experiment with id {experiment_id} not found" - ) - - experiment_dict = experiment.to_dict() - experiment_dict["notifications"] = [ - n.to_dict() - for n in await get_notifications_from_db( - experiment.experiment_id, experiment.user_id, asession - ) - ] - - return MultiArmedBanditResponse.model_validate(experiment_dict) - - -@router.delete("/{experiment_id}", response_model=dict) -async def delete_mab( - experiment_id: int, - user_db: Annotated[UserDB, Depends(require_admin_role)], - asession: AsyncSession = Depends(get_async_session), -) -> dict: - """ - Delete the experiment with the provided `experiment_id`. 
- """ - try: - workspace_db = await get_user_default_workspace( - asession=asession, user_db=user_db - ) - - if workspace_db is None: - raise HTTPException( - status_code=404, - detail="Workspace not found. Please create a workspace first.", - ) - - experiment = await get_mab_by_id( - experiment_id, workspace_db.workspace_id, asession - ) - if experiment is None: - raise HTTPException( - status_code=404, detail=f"Experiment with id {experiment_id} not found" - ) - await delete_mab_by_id(experiment_id, workspace_db.workspace_id, asession) - return {"message": f"Experiment with id {experiment_id} deleted successfully."} - except Exception as e: - raise HTTPException(status_code=500, detail=f"Error: {e}") from e - - -@router.get("/{experiment_id}/draw", response_model=MABDrawResponse) -async def draw_arm( - experiment_id: int, - draw_id: Optional[str] = None, - client_id: Optional[str] = None, - workspace_db: WorkspaceDB = Depends(authenticate_workspace_key), - asession: AsyncSession = Depends(get_async_session), -) -> MABDrawResponse: - """ - Draw an arm for the provided experiment. 
- """ - # Get workspace from user context - workspace_id = workspace_db.workspace_id - - experiment = await get_mab_by_id(experiment_id, workspace_id, asession) - if experiment is None: - raise HTTPException( - status_code=404, detail=f"Experiment with id {experiment_id} not found" - ) - - if experiment.sticky_assignment and client_id is None: - raise HTTPException( - status_code=400, - detail="Client ID is required for sticky assignment.", - ) - - # Check for existing draws - if draw_id is None: - draw_id = str(uuid4()) - - existing_draw = await get_draw_by_id(draw_id, asession) - if existing_draw: - raise HTTPException( - status_code=400, - detail=f"Draw ID {draw_id} already exists.", - ) - - experiment_data = MultiArmedBanditSample.model_validate(experiment) - chosen_arm = choose_arm(experiment=experiment_data) - chosen_arm_id = experiment.arms[chosen_arm].arm_id - - # If sticky assignment, check if the client_id has a previous arm assigned - if experiment.sticky_assignment and client_id: - previous_draw = await get_draw_by_client_id( - client_id=client_id, - experiment_id=experiment.experiment_id, - asession=asession, - ) - if previous_draw: - print(f"Previous draw found: {previous_draw.arm_id}") - chosen_arm_id = previous_draw.arm_id - - try: - _ = await save_draw_to_db( - experiment_id=experiment.experiment_id, - arm_id=chosen_arm_id, - draw_id=draw_id, - client_id=client_id, - user_id=None, - asession=asession, - workspace_id=workspace_id, - ) - except Exception as e: - raise HTTPException( - status_code=500, - detail=f"Error saving draw to database: {e}", - ) from e - - return MABDrawResponse.model_validate( - { - "draw_id": draw_id, - "client_id": client_id, - "arm": ArmResponse.model_validate( - [arm for arm in experiment.arms if arm.arm_id == chosen_arm_id][0] - ), - } - ) - - -@router.put("/{experiment_id}/{draw_id}/{outcome}", response_model=ArmResponse) -async def update_arm( - experiment_id: int, - draw_id: str, - outcome: float, - workspace_db: 
WorkspaceDB = Depends(authenticate_workspace_key), - asession: AsyncSession = Depends(get_async_session), -) -> ArmResponse: - """ - Update the arm with the provided `arm_id` for the given - `experiment_id` based on the `outcome`. - """ - # Get workspace from user context - workspace_id = workspace_db.workspace_id - - experiment, draw = await validate_experiment_and_draw( - experiment_id, draw_id, workspace_id, asession - ) - - return await update_based_on_outcome( - experiment, draw, outcome, asession, ObservationType.USER - ) - - -@router.get( - "/{experiment_id}/outcomes", - response_model=list[MABObservationResponse], -) -async def get_outcomes( - experiment_id: int, - workspace_db: WorkspaceDB = Depends(authenticate_workspace_key), - asession: AsyncSession = Depends(get_async_session), -) -> list[MABObservationResponse]: - """ - Get the outcomes for the experiment. - """ - # Get workspace from user context - workspace_id = workspace_db.workspace_id - - experiment = await get_mab_by_id(experiment_id, workspace_id, asession) - if not experiment: - raise HTTPException( - status_code=404, detail=f"Experiment with id {experiment_id} not found" - ) - - rewards = await get_all_obs_by_experiment_id( - experiment_id=experiment.experiment_id, - workspace_id=workspace_id, - asession=asession, - ) - - return [MABObservationResponse.model_validate(reward) for reward in rewards] - - -async def validate_experiment_and_draw( - experiment_id: int, - draw_id: str, - workspace_id: int, - asession: AsyncSession, -) -> tuple[MultiArmedBanditDB, MABDrawDB]: - """Validate the experiment and draw""" - experiment = await get_mab_by_id(experiment_id, workspace_id, asession) - if experiment is None: - raise HTTPException( - status_code=404, detail=f"Experiment with id {experiment_id} not found" - ) - - draw = await get_draw_by_id(draw_id=draw_id, asession=asession) - if draw is None: - raise HTTPException(status_code=404, detail=f"Draw with id {draw_id} not found") - - if 
draw.experiment_id != experiment_id: - raise HTTPException( - status_code=400, - detail=( - f"Draw with id {draw_id} does not belong " - f"to experiment with id {experiment_id}", - ), - ) - - if draw.reward is not None: - raise HTTPException( - status_code=400, - detail=f"Draw with id {draw_id} already has an outcome.", - ) - - return experiment, draw diff --git a/backend/app/mab/sampling_utils.py b/backend/app/mab/sampling_utils.py deleted file mode 100644 index 6bc5fe8..0000000 --- a/backend/app/mab/sampling_utils.py +++ /dev/null @@ -1,138 +0,0 @@ -import numpy as np -from numpy.random import beta, normal - -from ..mab.schemas import ArmResponse, MultiArmedBanditSample -from ..schemas import ArmPriors, Outcome, RewardLikelihood - - -def sample_beta_binomial(alphas: np.ndarray, betas: np.ndarray) -> int: - """ - Thompson Sampling with Beta-Binomial distribution. - - Parameters - ---------- - alphas : alpha parameter of Beta distribution for each arm - betas : beta parameter of Beta distribution for each arm - """ - samples = beta(alphas, betas) - return int(samples.argmax()) - - -def sample_normal(mus: np.ndarray, sigmas: np.ndarray) -> int: - """ - Thompson Sampling with conjugate normal distribution. - - Parameters - ---------- - mus: mean of Normal distribution for each arm - sigmas: standard deviation of Normal distribution for each arm - """ - samples = normal(loc=mus, scale=sigmas) - return int(samples.argmax()) - - -def update_arm_beta_binomial( - alpha: float, beta: float, reward: Outcome -) -> tuple[float, float]: - """ - Update the alpha and beta parameters of the Beta distribution. - - Parameters - ---------- - alpha : int - The alpha parameter of the Beta distribution. - beta : int - The beta parameter of the Beta distribution. - reward : Outcome - The reward of the arm. 
- """ - if reward == Outcome.SUCCESS: - - return alpha + 1, beta - else: - return alpha, beta + 1 - - -def update_arm_normal( - current_mu: float, current_sigma: float, reward: float, sigma_llhood: float -) -> tuple[float, float]: - """ - Update the mean and standard deviation of the Normal distribution. - - Parameters - ---------- - current_mu : The mean of the Normal distribution. - current_sigma : The standard deviation of the Normal distribution. - reward : The reward of the arm. - sigma_llhood : The likelihood of the standard deviation. - """ - denom = sigma_llhood**2 + current_sigma**2 - new_sigma = sigma_llhood * current_sigma / np.sqrt(denom) - new_mu = (current_mu * sigma_llhood**2 + reward * current_sigma**2) / denom - return new_mu, new_sigma - - -def choose_arm(experiment: MultiArmedBanditSample) -> int: - """ - Choose arm based on posterior - - Parameters - ---------- - experiment : MultiArmedBanditResponse - The experiment data containing priors and rewards for each arm. - """ - if (experiment.prior_type == ArmPriors.BETA) and ( - experiment.reward_type == RewardLikelihood.BERNOULLI - ): - alphas = np.array([arm.alpha for arm in experiment.arms]) - betas = np.array([arm.beta for arm in experiment.arms]) - - return sample_beta_binomial(alphas=alphas, betas=betas) - - elif (experiment.prior_type == ArmPriors.NORMAL) and ( - experiment.reward_type == RewardLikelihood.NORMAL - ): - mus = np.array([arm.mu for arm in experiment.arms]) - sigmas = np.array([arm.sigma for arm in experiment.arms]) - # TODO: add support for non-std sigma_llhood - return sample_normal(mus=mus, sigmas=sigmas) - else: - raise ValueError("Prior and reward type combination is not supported.") - - -def update_arm_params( - arm: ArmResponse, - prior_type: ArmPriors, - reward_type: RewardLikelihood, - reward: float, -) -> tuple: - """ - Update the arm with the provided `arm_id` based on the `reward`. - - Parameters - ---------- - arm: The arm to update. 
- prior_type: The type of prior distribution for the arms. - reward_type: The likelihood distribution of the reward. - reward: The reward of the arm. - """ - - if (prior_type == ArmPriors.BETA) and (reward_type == RewardLikelihood.BERNOULLI): - if arm.alpha is None or arm.beta is None: - raise ValueError("Beta prior requires alpha and beta.") - outcome = Outcome(reward) - return update_arm_beta_binomial(alpha=arm.alpha, beta=arm.beta, reward=outcome) - - elif ( - (prior_type == ArmPriors.NORMAL) - and (reward_type == RewardLikelihood.NORMAL) - and (arm.mu and arm.sigma) - ): - return update_arm_normal( - current_mu=arm.mu, - current_sigma=arm.sigma, - reward=reward, - sigma_llhood=1.0, # TODO: add support for non-std sigma_llhood - ) - else: - raise ValueError("Prior and reward type combination is not supported.") diff --git a/backend/app/mab/schemas.py b/backend/app/mab/schemas.py deleted file mode 100644 index 60fbf0e..0000000 --- a/backend/app/mab/schemas.py +++ /dev/null @@ -1,262 +0,0 @@ -from datetime import datetime -from typing import Optional, Self - -from pydantic import BaseModel, ConfigDict, Field, model_validator - -from ..schemas import ( - ArmPriors, - AutoFailUnitType, - Notifications, - NotificationsResponse, - RewardLikelihood, - allowed_combos_mab, -) - - -class Arm(BaseModel): - """ - Pydantic model for a arm of the experiment. 
- """ - - name: str = Field( - max_length=150, - examples=["Arm 1"], - ) - description: str = Field( - max_length=500, - examples=["This is a description of the arm."], - ) - - # prior variables - alpha_init: Optional[float] = Field( - default=None, examples=[None, 1.0], description="Alpha parameter for Beta prior" - ) - beta_init: Optional[float] = Field( - default=None, examples=[None, 1.0], description="Beta parameter for Beta prior" - ) - mu_init: Optional[float] = Field( - default=None, - examples=[None, 0.0], - description="Mean parameter for Normal prior", - ) - sigma_init: Optional[float] = Field( - default=None, - examples=[None, 1.0], - description="Standard deviation parameter for Normal prior", - ) - n_outcomes: Optional[int] = Field( - default=0, - description="Number of outcomes for the arm", - examples=[0, 10, 15], - ) - - @model_validator(mode="after") - def check_values(self) -> Self: - """ - Check if the values are unique. - """ - alpha = self.alpha_init - beta = self.beta_init - sigma = self.sigma_init - if alpha is not None and alpha <= 0: - raise ValueError("Alpha must be greater than 0.") - if beta is not None and beta <= 0: - raise ValueError("Beta must be greater than 0.") - if sigma is not None and sigma <= 0: - raise ValueError("Sigma must be greater than 0.") - return self - - -class ArmResponse(Arm): - """ - Pydantic model for an response for arm creation - """ - - arm_id: int - alpha: Optional[float] - beta: Optional[float] - mu: Optional[float] - sigma: Optional[float] - model_config = ConfigDict( - from_attributes=True, - ) - - -class MultiArmedBanditBase(BaseModel): - """ - Pydantic model for an experiment - Base model. - Note: Do not use this model directly. Use `MultiArmedBandit` instead. 
- """ - - name: str = Field( - max_length=150, - examples=["Experiment 1"], - ) - - description: str = Field( - max_length=500, - examples=["This is a description of the experiment."], - ) - - sticky_assignment: bool = Field( - description="Whether the arm assignment is sticky or not.", - default=False, - ) - - auto_fail: bool = Field( - description=( - "Whether the experiment should fail automatically after " - "a certain period if no outcome is registered." - ), - default=False, - ) - - auto_fail_value: Optional[int] = Field( - description="The time period after which the experiment should fail.", - default=None, - ) - - auto_fail_unit: Optional[AutoFailUnitType] = Field( - description="The time unit for the auto fail period.", - default=None, - ) - - reward_type: RewardLikelihood = Field( - description="The type of reward we observe from the experiment.", - default=RewardLikelihood.BERNOULLI, - ) - prior_type: ArmPriors = Field( - description="The type of prior distribution for the arms.", - default=ArmPriors.BETA, - ) - - is_active: bool = True - - model_config = ConfigDict(from_attributes=True) - - -class MultiArmedBandit(MultiArmedBanditBase): - """ - Pydantic model for an experiment. - """ - - arms: list[Arm] - notifications: Notifications - - @model_validator(mode="after") - def auto_fail_unit_and_value_set(self) -> Self: - """ - Validate that the auto fail unit and value are set if auto fail is True. - """ - if self.auto_fail: - if ( - not self.auto_fail_value - or not self.auto_fail_unit - or self.auto_fail_value <= 0 - ): - raise ValueError( - ( - "Auto fail is enabled. " - "Please provide both auto_fail_value and auto_fail_unit." - ) - ) - return self - - @model_validator(mode="after") - def arms_at_least_two(self) -> Self: - """ - Validate that the experiment has at least two arms. 
- """ - if len(self.arms) < 2: - raise ValueError("The experiment must have at least two arms.") - return self - - @model_validator(mode="after") - def check_prior_reward_type_combo(self) -> Self: - """ - Validate that the prior and reward type combination is allowed. - """ - - if (self.prior_type, self.reward_type) not in allowed_combos_mab: - raise ValueError("Prior and reward type combo not supported.") - return self - - @model_validator(mode="after") - def check_arm_missing_params(self) -> Self: - """ - Check if the arm reward type is same as the experiment reward type. - """ - prior_type = self.prior_type - arms = self.arms - - prior_params = { - ArmPriors.BETA: ("alpha_init", "beta_init"), - ArmPriors.NORMAL: ("mu_init", "sigma_init"), - } - - for arm in arms: - arm_dict = arm.model_dump() - if prior_type in prior_params: - missing_params = [] - for param in prior_params[prior_type]: - if param not in arm_dict.keys(): - missing_params.append(param) - elif arm_dict[param] is None: - missing_params.append(param) - - if missing_params: - val = prior_type.value - raise ValueError(f"{val} prior needs {','.join(missing_params)}.") - return self - - model_config = ConfigDict(from_attributes=True) - - -class MultiArmedBanditResponse(MultiArmedBanditBase): - """ - Pydantic model for an response for experiment creation. - Returns the id of the experiment and the arms - """ - - experiment_id: int - workspace_id: int - arms: list[ArmResponse] - notifications: list[NotificationsResponse] - created_datetime_utc: datetime - last_trial_datetime_utc: Optional[datetime] = None - n_trials: int - model_config = ConfigDict(from_attributes=True, revalidate_instances="always") - - -class MultiArmedBanditSample(MultiArmedBanditBase): - """ - Pydantic model for an experiment sample. - """ - - experiment_id: int - arms: list[ArmResponse] - - -class MABObservationResponse(BaseModel): - """ - Pydantic model for binary observations of the experiment. 
- """ - - experiment_id: int - arm_id: int - reward: float - draw_id: str - client_id: str | None - observed_datetime_utc: datetime - - model_config = ConfigDict(from_attributes=True) - - -class MABDrawResponse(BaseModel): - """ - Pydantic model for the response of the draw endpoint. - """ - - draw_id: str - client_id: str | None - arm: ArmResponse diff --git a/backend/migrations/versions/157dd34c4cd4_debugging_nullable_variable.py b/backend/migrations/versions/157dd34c4cd4_debugging_nullable_variable.py deleted file mode 100644 index a46adf5..0000000 --- a/backend/migrations/versions/157dd34c4cd4_debugging_nullable_variable.py +++ /dev/null @@ -1,34 +0,0 @@ -"""debugging nullable variable - -Revision ID: 157dd34c4cd4 -Revises: 4c06937ee88f -Create Date: 2025-05-30 15:43:35.254416 - -""" - -from typing import Sequence, Union - -import sqlalchemy as sa -from alembic import op - -# revision identifiers, used by Alembic. -revision: str = "157dd34c4cd4" -down_revision: Union[str, None] = "4c06937ee88f" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.alter_column( - "draws", "client_id", existing_type=sa.VARCHAR(length=36), nullable=True - ) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! 
### - op.alter_column( - "draws", "client_id", existing_type=sa.VARCHAR(length=36), nullable=False - ) - # ### end Alembic commands ### diff --git a/backend/migrations/versions/4c06937ee88f_update_models_for_treatment_arm_and_.py b/backend/migrations/versions/4c06937ee88f_update_models_for_treatment_arm_and_.py deleted file mode 100644 index 4a6f7d6..0000000 --- a/backend/migrations/versions/4c06937ee88f_update_models_for_treatment_arm_and_.py +++ /dev/null @@ -1,64 +0,0 @@ -"""update models for treatment arm and debugging - -Revision ID: 4c06937ee88f -Revises: 57173e1aa8ae -Create Date: 2025-05-30 12:14:04.889301 - -""" - -from typing import Sequence, Union - -import sqlalchemy as sa -from alembic import op - -# revision identifiers, used by Alembic. -revision: str = "4c06937ee88f" -down_revision: Union[str, None] = "57173e1aa8ae" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.add_column("arms", sa.Column("is_treatment_arm", sa.Boolean(), nullable=True)) - op.drop_constraint("arms_user_id_fkey", "arms", type_="foreignkey") - op.drop_column("arms", "user_id") - op.drop_constraint("clients_user_id_fkey", "clients", type_="foreignkey") - op.drop_column("clients", "user_id") - op.drop_constraint("context_user_id_fkey", "context", type_="foreignkey") - op.drop_column("context", "user_id") - op.drop_constraint("draws_user_id_fkey", "draws", type_="foreignkey") - op.drop_column("draws", "user_id") - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! 
### - op.add_column( - "draws", sa.Column("user_id", sa.INTEGER(), autoincrement=False, nullable=False) - ) - op.create_foreign_key( - "draws_user_id_fkey", "draws", "users", ["user_id"], ["user_id"] - ) - op.add_column( - "context", - sa.Column("user_id", sa.INTEGER(), autoincrement=False, nullable=False), - ) - op.create_foreign_key( - "context_user_id_fkey", "context", "users", ["user_id"], ["user_id"] - ) - op.add_column( - "clients", - sa.Column("user_id", sa.INTEGER(), autoincrement=False, nullable=False), - ) - op.create_foreign_key( - "clients_user_id_fkey", "clients", "users", ["user_id"], ["user_id"] - ) - op.add_column( - "arms", sa.Column("user_id", sa.INTEGER(), autoincrement=False, nullable=False) - ) - op.create_foreign_key( - "arms_user_id_fkey", "arms", "users", ["user_id"], ["user_id"] - ) - op.drop_column("arms", "is_treatment_arm") - # ### end Alembic commands ### diff --git a/backend/migrations/versions/57173e1aa8ae_update_tables_with_workspace_id.py b/backend/migrations/versions/57173e1aa8ae_update_tables_with_workspace_id.py deleted file mode 100644 index 19bc51e..0000000 --- a/backend/migrations/versions/57173e1aa8ae_update_tables_with_workspace_id.py +++ /dev/null @@ -1,131 +0,0 @@ -"""update tables with workspace id - -Revision ID: 57173e1aa8ae -Revises: 2d3946caceff -Create Date: 2025-05-27 21:10:55.499461 - -""" - -from typing import Sequence, Union - -import sqlalchemy as sa -from alembic import op -from sqlalchemy.dialects import postgresql - -# revision identifiers, used by Alembic. -revision: str = "57173e1aa8ae" -down_revision: Union[str, None] = "2d3946caceff" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! 
### - op.drop_table("contextual_arms") - op.drop_table("contexts") - op.drop_table("contextual_mabs") - op.drop_table("contextual_draws") - op.add_column("context", sa.Column("workspace_id", sa.Integer(), nullable=False)) - op.create_foreign_key( - None, "context", "workspace", ["workspace_id"], ["workspace_id"] - ) - op.add_column("draws", sa.Column("workspace_id", sa.Integer(), nullable=False)) - op.create_foreign_key( - None, "draws", "workspace", ["workspace_id"], ["workspace_id"] - ) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.drop_constraint(None, "draws", type_="foreignkey") - op.drop_column("draws", "workspace_id") - op.drop_constraint(None, "context", type_="foreignkey") - op.drop_column("context", "workspace_id") - op.create_table( - "contextual_draws", - sa.Column("draw_id", sa.VARCHAR(), autoincrement=False, nullable=False), - sa.Column( - "context_val", - postgresql.ARRAY(sa.DOUBLE_PRECISION(precision=53)), - autoincrement=False, - nullable=False, - ), - sa.ForeignKeyConstraint( - ["draw_id"], - ["draws_base.draw_id"], - name="contextual_draws_draw_id_fkey", - ondelete="CASCADE", - ), - sa.PrimaryKeyConstraint("draw_id", name="contextual_draws_pkey"), - ) - op.create_table( - "contextual_mabs", - sa.Column("experiment_id", sa.INTEGER(), autoincrement=False, nullable=False), - sa.ForeignKeyConstraint( - ["experiment_id"], - ["experiments_base.experiment_id"], - name="contextual_mabs_experiment_id_fkey", - ondelete="CASCADE", - ), - sa.PrimaryKeyConstraint("experiment_id", name="contextual_mabs_pkey"), - postgresql_ignore_search_path=False, - ) - op.create_table( - "contexts", - sa.Column("context_id", sa.INTEGER(), autoincrement=True, nullable=False), - sa.Column("experiment_id", sa.INTEGER(), autoincrement=False, nullable=False), - sa.Column("user_id", sa.INTEGER(), autoincrement=False, nullable=False), - sa.Column("name", sa.VARCHAR(length=150), 
autoincrement=False, nullable=False), - sa.Column( - "description", sa.VARCHAR(length=500), autoincrement=False, nullable=True - ), - sa.Column( - "value_type", sa.VARCHAR(length=50), autoincrement=False, nullable=False - ), - sa.ForeignKeyConstraint( - ["experiment_id"], - ["contextual_mabs.experiment_id"], - name="contexts_experiment_id_fkey", - ), - sa.ForeignKeyConstraint( - ["user_id"], ["users.user_id"], name="contexts_user_id_fkey" - ), - sa.PrimaryKeyConstraint("context_id", name="contexts_pkey"), - ) - op.create_table( - "contextual_arms", - sa.Column("arm_id", sa.INTEGER(), autoincrement=False, nullable=False), - sa.Column( - "mu_init", - sa.DOUBLE_PRECISION(precision=53), - autoincrement=False, - nullable=False, - ), - sa.Column( - "sigma_init", - sa.DOUBLE_PRECISION(precision=53), - autoincrement=False, - nullable=False, - ), - sa.Column( - "mu", - postgresql.ARRAY(sa.DOUBLE_PRECISION(precision=53)), - autoincrement=False, - nullable=False, - ), - sa.Column( - "covariance", - postgresql.ARRAY(sa.DOUBLE_PRECISION(precision=53)), - autoincrement=False, - nullable=False, - ), - sa.ForeignKeyConstraint( - ["arm_id"], - ["arms_base.arm_id"], - name="contextual_arms_arm_id_fkey", - ondelete="CASCADE", - ), - sa.PrimaryKeyConstraint("arm_id", name="contextual_arms_pkey"), - ) - # ### end Alembic commands ### diff --git a/backend/migrations/versions/2d3946caceff_new_start.py b/backend/migrations/versions/6101ba814d91_fresh_start.py similarity index 76% rename from backend/migrations/versions/2d3946caceff_new_start.py rename to backend/migrations/versions/6101ba814d91_fresh_start.py index 820aa7c..d246310 100644 --- a/backend/migrations/versions/2d3946caceff_new_start.py +++ b/backend/migrations/versions/6101ba814d91_fresh_start.py @@ -1,8 +1,8 @@ -"""new start +"""fresh start -Revision ID: 2d3946caceff +Revision ID: 6101ba814d91 Revises: -Create Date: 2025-05-27 18:39:15.282285 +Create Date: 2025-06-03 18:00:18.919218 """ @@ -13,7 +13,7 @@ from 
sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. -revision: str = "2d3946caceff" +revision: str = "6101ba814d91" down_revision: Union[str, None] = None branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -201,7 +201,6 @@ def upgrade() -> None: op.create_table( "arms", sa.Column("arm_id", sa.Integer(), nullable=False), - sa.Column("user_id", sa.Integer(), nullable=False), sa.Column("workspace_id", sa.Integer(), nullable=False), sa.Column("experiment_id", sa.Integer(), nullable=False), sa.Column("name", sa.String(length=150), nullable=False), @@ -211,6 +210,7 @@ def upgrade() -> None: sa.Column("sigma_init", sa.Float(), nullable=True), sa.Column("mu", postgresql.ARRAY(sa.Float()), nullable=True), sa.Column("covariance", postgresql.ARRAY(sa.Float()), nullable=True), + sa.Column("is_treatment_arm", sa.Boolean(), nullable=True), sa.Column("alpha_init", sa.Float(), nullable=True), sa.Column("beta_init", sa.Float(), nullable=True), sa.Column("alpha", sa.Float(), nullable=True), @@ -219,10 +219,6 @@ def upgrade() -> None: ["experiment_id"], ["experiments.experiment_id"], ), - sa.ForeignKeyConstraint( - ["user_id"], - ["users.user_id"], - ), sa.ForeignKeyConstraint( ["workspace_id"], ["workspace.workspace_id"], @@ -248,28 +244,15 @@ def upgrade() -> None: ), sa.PrimaryKeyConstraint("arm_id"), ) - op.create_table( - "bayes_ab_experiments", - sa.Column("experiment_id", sa.Integer(), nullable=False), - sa.ForeignKeyConstraint( - ["experiment_id"], ["experiments_base.experiment_id"], ondelete="CASCADE" - ), - sa.PrimaryKeyConstraint("experiment_id"), - ) op.create_table( "clients", sa.Column("client_id", sa.String(), nullable=False), - sa.Column("user_id", sa.Integer(), nullable=False), sa.Column("experiment_id", sa.Integer(), nullable=False), sa.Column("workspace_id", sa.Integer(), nullable=False), sa.ForeignKeyConstraint( ["experiment_id"], ["experiments.experiment_id"], ), - 
sa.ForeignKeyConstraint( - ["user_id"], - ["users.user_id"], - ), sa.ForeignKeyConstraint( ["workspace_id"], ["workspace.workspace_id"], @@ -280,7 +263,7 @@ def upgrade() -> None: "context", sa.Column("context_id", sa.Integer(), nullable=False), sa.Column("experiment_id", sa.Integer(), nullable=False), - sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("workspace_id", sa.Integer(), nullable=False), sa.Column("name", sa.String(length=150), nullable=False), sa.Column("description", sa.String(length=500), nullable=True), sa.Column("value_type", sa.String(length=50), nullable=False), @@ -289,19 +272,11 @@ def upgrade() -> None: ["experiments.experiment_id"], ), sa.ForeignKeyConstraint( - ["user_id"], - ["users.user_id"], + ["workspace_id"], + ["workspace.workspace_id"], ), sa.PrimaryKeyConstraint("context_id"), ) - op.create_table( - "contextual_mabs", - sa.Column("experiment_id", sa.Integer(), nullable=False), - sa.ForeignKeyConstraint( - ["experiment_id"], ["experiments_base.experiment_id"], ondelete="CASCADE" - ), - sa.PrimaryKeyConstraint("experiment_id"), - ) op.create_table( "event_messages", sa.Column("message_id", sa.Integer(), nullable=False), @@ -315,14 +290,6 @@ def upgrade() -> None: ), sa.PrimaryKeyConstraint("message_id"), ) - op.create_table( - "mabs", - sa.Column("experiment_id", sa.Integer(), nullable=False), - sa.ForeignKeyConstraint( - ["experiment_id"], ["experiments_base.experiment_id"], ondelete="CASCADE" - ), - sa.PrimaryKeyConstraint("experiment_id"), - ) op.create_table( "notifications", sa.Column("notification_id", sa.Integer(), nullable=False), @@ -382,52 +349,13 @@ def upgrade() -> None: ), sa.PrimaryKeyConstraint("notification_id"), ) - op.create_table( - "bayes_ab_arms", - sa.Column("arm_id", sa.Integer(), nullable=False), - sa.Column("mu_init", sa.Float(), nullable=False), - sa.Column("sigma_init", sa.Float(), nullable=False), - sa.Column("mu", sa.Float(), nullable=False), - sa.Column("sigma", sa.Float(), nullable=False), - 
sa.Column("is_treatment_arm", sa.Boolean(), nullable=False), - sa.ForeignKeyConstraint(["arm_id"], ["arms_base.arm_id"], ondelete="CASCADE"), - sa.PrimaryKeyConstraint("arm_id"), - ) - op.create_table( - "contexts", - sa.Column("context_id", sa.Integer(), nullable=False), - sa.Column("experiment_id", sa.Integer(), nullable=False), - sa.Column("user_id", sa.Integer(), nullable=False), - sa.Column("name", sa.String(length=150), nullable=False), - sa.Column("description", sa.String(length=500), nullable=True), - sa.Column("value_type", sa.String(length=50), nullable=False), - sa.ForeignKeyConstraint( - ["experiment_id"], - ["contextual_mabs.experiment_id"], - ), - sa.ForeignKeyConstraint( - ["user_id"], - ["users.user_id"], - ), - sa.PrimaryKeyConstraint("context_id"), - ) - op.create_table( - "contextual_arms", - sa.Column("arm_id", sa.Integer(), nullable=False), - sa.Column("mu_init", sa.Float(), nullable=False), - sa.Column("sigma_init", sa.Float(), nullable=False), - sa.Column("mu", postgresql.ARRAY(sa.Float()), nullable=False), - sa.Column("covariance", postgresql.ARRAY(sa.Float()), nullable=False), - sa.ForeignKeyConstraint(["arm_id"], ["arms_base.arm_id"], ondelete="CASCADE"), - sa.PrimaryKeyConstraint("arm_id"), - ) op.create_table( "draws", sa.Column("draw_id", sa.String(), nullable=False), sa.Column("arm_id", sa.Integer(), nullable=False), sa.Column("experiment_id", sa.Integer(), nullable=False), - sa.Column("user_id", sa.Integer(), nullable=False), - sa.Column("client_id", sa.String(length=36), nullable=False), + sa.Column("workspace_id", sa.Integer(), nullable=False), + sa.Column("client_id", sa.String(length=36), nullable=True), sa.Column("draw_datetime_utc", sa.DateTime(timezone=True), nullable=False), sa.Column("observed_datetime_utc", sa.DateTime(timezone=True), nullable=True), sa.Column( @@ -450,8 +378,8 @@ def upgrade() -> None: ["experiments.experiment_id"], ), sa.ForeignKeyConstraint( - ["user_id"], - ["users.user_id"], + ["workspace_id"], + 
["workspace.workspace_id"], ), sa.PrimaryKeyConstraint("draw_id"), ) @@ -485,67 +413,18 @@ def upgrade() -> None: ), sa.PrimaryKeyConstraint("draw_id"), ) - op.create_table( - "mab_arms", - sa.Column("arm_id", sa.Integer(), nullable=False), - sa.Column("alpha", sa.Float(), nullable=True), - sa.Column("beta", sa.Float(), nullable=True), - sa.Column("mu", sa.Float(), nullable=True), - sa.Column("sigma", sa.Float(), nullable=True), - sa.Column("alpha_init", sa.Float(), nullable=True), - sa.Column("beta_init", sa.Float(), nullable=True), - sa.Column("mu_init", sa.Float(), nullable=True), - sa.Column("sigma_init", sa.Float(), nullable=True), - sa.ForeignKeyConstraint(["arm_id"], ["arms_base.arm_id"], ondelete="CASCADE"), - sa.PrimaryKeyConstraint("arm_id"), - ) - op.create_table( - "bayes_ab_draws", - sa.Column("draw_id", sa.String(), nullable=False), - sa.ForeignKeyConstraint( - ["draw_id"], ["draws_base.draw_id"], ondelete="CASCADE" - ), - sa.PrimaryKeyConstraint("draw_id"), - ) - op.create_table( - "contextual_draws", - sa.Column("draw_id", sa.String(), nullable=False), - sa.Column("context_val", postgresql.ARRAY(sa.Float()), nullable=False), - sa.ForeignKeyConstraint( - ["draw_id"], ["draws_base.draw_id"], ondelete="CASCADE" - ), - sa.PrimaryKeyConstraint("draw_id"), - ) - op.create_table( - "mab_draws", - sa.Column("draw_id", sa.String(), nullable=False), - sa.ForeignKeyConstraint( - ["draw_id"], ["draws_base.draw_id"], ondelete="CASCADE" - ), - sa.PrimaryKeyConstraint("draw_id"), - ) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! 
### - op.drop_table("mab_draws") - op.drop_table("contextual_draws") - op.drop_table("bayes_ab_draws") - op.drop_table("mab_arms") op.drop_table("draws_base") op.drop_table("draws") - op.drop_table("contextual_arms") - op.drop_table("contexts") - op.drop_table("bayes_ab_arms") op.drop_table("notifications_db") op.drop_table("notifications") - op.drop_table("mabs") op.drop_table("event_messages") - op.drop_table("contextual_mabs") op.drop_table("context") op.drop_table("clients") - op.drop_table("bayes_ab_experiments") op.drop_table("arms_base") op.drop_table("arms") op.drop_table("user_workspace") diff --git a/backend/migrations/versions/9f7482ba882f_workspace_model.py b/backend/migrations/versions/9f7482ba882f_workspace_model.py deleted file mode 100644 index 3543211..0000000 --- a/backend/migrations/versions/9f7482ba882f_workspace_model.py +++ /dev/null @@ -1,123 +0,0 @@ -"""Workspace model - -Revision ID: 9f7482ba882f -Revises: 275ff74c0866 -Create Date: 2025-05-04 11:56:03.939578 - -""" - -from typing import Sequence, Union - -import sqlalchemy as sa -from alembic import op - -# revision identifiers, used by Alembic. -revision: str = "9f7482ba882f" -down_revision: Union[str, None] = "275ff74c0866" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! 
### - op.create_table( - "workspace", - sa.Column("api_daily_quota", sa.Integer(), nullable=True), - sa.Column("api_key_first_characters", sa.String(length=5), nullable=True), - sa.Column( - "api_key_updated_datetime_utc", sa.DateTime(timezone=True), nullable=True - ), - sa.Column("api_key_rotated_by_user_id", sa.Integer(), nullable=True), - sa.Column("content_quota", sa.Integer(), nullable=True), - sa.Column("created_datetime_utc", sa.DateTime(timezone=True), nullable=False), - sa.Column("hashed_api_key", sa.String(length=96), nullable=True), - sa.Column("updated_datetime_utc", sa.DateTime(timezone=True), nullable=False), - sa.Column("workspace_id", sa.Integer(), nullable=False), - sa.Column("workspace_name", sa.String(), nullable=False), - sa.Column("is_default", sa.Boolean(), nullable=False), - sa.ForeignKeyConstraint( - ["api_key_rotated_by_user_id"], - ["users.user_id"], - ), - sa.PrimaryKeyConstraint("workspace_id"), - sa.UniqueConstraint("hashed_api_key"), - sa.UniqueConstraint("workspace_name"), - ) - op.create_table( - "api_key_rotation_history", - sa.Column("rotation_id", sa.Integer(), nullable=False), - sa.Column("workspace_id", sa.Integer(), nullable=False), - sa.Column("rotated_by_user_id", sa.Integer(), nullable=False), - sa.Column("key_first_characters", sa.String(length=5), nullable=False), - sa.Column("rotation_datetime_utc", sa.DateTime(timezone=True), nullable=False), - sa.ForeignKeyConstraint( - ["rotated_by_user_id"], - ["users.user_id"], - ), - sa.ForeignKeyConstraint( - ["workspace_id"], ["workspace.workspace_id"], ondelete="CASCADE" - ), - sa.PrimaryKeyConstraint("rotation_id"), - ) - op.create_table( - "pending_invitations", - sa.Column("invitation_id", sa.Integer(), nullable=False), - sa.Column("email", sa.String(), nullable=False), - sa.Column("workspace_id", sa.Integer(), nullable=False), - sa.Column( - "role", - sa.Enum("ADMIN", "READ_ONLY", name="userroles", native_enum=False), - nullable=False, - ), - sa.Column("inviter_id", 
sa.Integer(), nullable=False), - sa.Column("created_datetime_utc", sa.DateTime(timezone=True), nullable=False), - sa.ForeignKeyConstraint( - ["inviter_id"], - ["users.user_id"], - ), - sa.ForeignKeyConstraint( - ["workspace_id"], ["workspace.workspace_id"], ondelete="CASCADE" - ), - sa.PrimaryKeyConstraint("invitation_id"), - ) - op.create_table( - "user_workspace", - sa.Column("created_datetime_utc", sa.DateTime(timezone=True), nullable=False), - sa.Column( - "default_workspace", - sa.Boolean(), - server_default=sa.text("false"), - nullable=False, - ), - sa.Column("updated_datetime_utc", sa.DateTime(timezone=True), nullable=False), - sa.Column("user_id", sa.Integer(), nullable=False), - sa.Column( - "user_role", - sa.Enum("ADMIN", "READ_ONLY", name="userroles", native_enum=False), - nullable=False, - ), - sa.Column("workspace_id", sa.Integer(), nullable=False), - sa.ForeignKeyConstraint(["user_id"], ["users.user_id"], ondelete="CASCADE"), - sa.ForeignKeyConstraint( - ["workspace_id"], ["workspace.workspace_id"], ondelete="CASCADE" - ), - sa.PrimaryKeyConstraint("user_id", "workspace_id"), - ) - op.add_column( - "experiments_base", sa.Column("workspace_id", sa.Integer(), nullable=False) - ) - op.create_foreign_key( - None, "experiments_base", "workspace", ["workspace_id"], ["workspace_id"] - ) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! 
### - op.drop_constraint(None, "experiments_base", type_="foreignkey") - op.drop_column("experiments_base", "workspace_id") - op.drop_table("user_workspace") - op.drop_table("pending_invitations") - op.drop_table("api_key_rotation_history") - op.drop_table("workspace") - # ### end Alembic commands ### diff --git a/backend/migrations/versions/ecddd830b464_remove_user_api_key.py b/backend/migrations/versions/ecddd830b464_remove_user_api_key.py deleted file mode 100644 index b03b032..0000000 --- a/backend/migrations/versions/ecddd830b464_remove_user_api_key.py +++ /dev/null @@ -1,70 +0,0 @@ -"""Remove User API key - -Revision ID: ecddd830b464 -Revises: 9f7482ba882f -Create Date: 2025-05-21 13:59:22.199884 - -""" - -from typing import Sequence, Union - -import sqlalchemy as sa -from alembic import op -from sqlalchemy.dialects import postgresql - -# revision identifiers, used by Alembic. -revision: str = "ecddd830b464" -down_revision: Union[str, None] = "9f7482ba882f" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.drop_constraint("users_hashed_api_key_key", "users", type_="unique") - op.drop_column("users", "api_daily_quota") - op.drop_column("users", "hashed_api_key") - op.drop_column("users", "api_key_updated_datetime_utc") - op.drop_column("users", "api_key_first_characters") - op.drop_column("users", "experiments_quota") - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! 
### - op.add_column( - "users", - sa.Column( - "experiments_quota", sa.INTEGER(), autoincrement=False, nullable=True - ), - ) - op.add_column( - "users", - sa.Column( - "api_key_first_characters", - sa.VARCHAR(length=5), - autoincrement=False, - nullable=False, - ), - ) - op.add_column( - "users", - sa.Column( - "api_key_updated_datetime_utc", - postgresql.TIMESTAMP(timezone=True), - autoincrement=False, - nullable=False, - ), - ) - op.add_column( - "users", - sa.Column( - "hashed_api_key", sa.VARCHAR(length=96), autoincrement=False, nullable=False - ), - ) - op.add_column( - "users", - sa.Column("api_daily_quota", sa.INTEGER(), autoincrement=False, nullable=True), - ) - op.create_unique_constraint("users_hashed_api_key_key", "users", ["hashed_api_key"]) - # ### end Alembic commands ### From 40e6c779614a5c19013902957670799271daaf79 Mon Sep 17 00:00:00 2001 From: poornimaramesh Date: Wed, 4 Jun 2025 13:17:17 +0300 Subject: [PATCH 51/74] debugging --- backend/app/experiments/dependencies.py | 61 ++++++++++++----------- backend/app/experiments/routers.py | 1 + backend/app/experiments/sampling_utils.py | 15 +++--- 3 files changed, 40 insertions(+), 37 deletions(-) diff --git a/backend/app/experiments/dependencies.py b/backend/app/experiments/dependencies.py index f9c7380..249a989 100644 --- a/backend/app/experiments/dependencies.py +++ b/backend/app/experiments/dependencies.py @@ -98,7 +98,11 @@ async def validate_experiment_and_draw( async def format_rewards_for_arm_update( - experiment: ExperimentDB, chosen_arm_id: int, reward: float, asession: AsyncSession + experiment: ExperimentDB, + chosen_arm_id: int, + reward: float, + context_val: Union[list[float], None], + asession: AsyncSession, ) -> tuple[list[float], list[list[float]] | None, list[float] | None]: """ Format the rewards for the arm update. 
@@ -106,43 +110,42 @@ async def format_rewards_for_arm_update( previous_rewards = await get_draws_with_rewards_by_experiment_id( experiment_id=experiment.experiment_id, asession=asession ) - if not previous_rewards: - return [], None, None rewards = [] treatments = None contexts = None - if experiment.exp_type != ExperimentsEnum.BAYESAB.value: - rewards = [ - draw.reward for draw in previous_rewards if draw.arm_id == chosen_arm_id - ] - else: - treatments = [] - for draw in previous_rewards: - rewards.append(draw.reward) - treatments.append( - [ - float(arm.is_treatment_arm) - for arm in experiment.arms - if arm.arm_id == draw.arm_id - ][0] - ) - - if experiment.exp_type == ExperimentsEnum.CMAB.value: - contexts = [] - for draw in previous_rewards: - if draw.context_val: - contexts.append(draw.context_val) - else: - raise ValueError( - f"Context value is missing for draw id {draw.draw_id}" - f" in CMAB experiment {draw.experiment_id}." + if previous_rewards: + if experiment.exp_type != ExperimentsEnum.BAYESAB.value: + rewards = [ + draw.reward for draw in previous_rewards if draw.arm_id == chosen_arm_id + ] + else: + treatments = [] + for draw in previous_rewards: + rewards.append(draw.reward) + treatments.append( + [ + float(arm.is_treatment_arm) + for arm in experiment.arms + if arm.arm_id == draw.arm_id + ][0] ) + if experiment.exp_type == ExperimentsEnum.CMAB.value: + contexts = [] + for draw in previous_rewards: + if draw.context_val: + contexts.append(draw.context_val) + else: + raise ValueError( + f"Context value is missing for draw id {draw.draw_id}" + f" in CMAB experiment {draw.experiment_id}." 
+ ) + rewards_list = [reward] if rewards is None else [reward] + rewards - context_list = None if not draw.context_val else [draw.context_val] + context_list = None if not context_val else [context_val] if contexts and context_list: context_list = context_list + contexts diff --git a/backend/app/experiments/routers.py b/backend/app/experiments/routers.py index 27f885d..58431ad 100644 --- a/backend/app/experiments/routers.py +++ b/backend/app/experiments/routers.py @@ -419,6 +419,7 @@ async def update_experiment_arm( experiment=experiment, chosen_arm_id=draw.arm_id, reward=reward, + context_val=draw.context_val, asession=asession, ) diff --git a/backend/app/experiments/sampling_utils.py b/backend/app/experiments/sampling_utils.py index 31e518d..ac9db72 100644 --- a/backend/app/experiments/sampling_utils.py +++ b/backend/app/experiments/sampling_utils.py @@ -83,7 +83,7 @@ def _update_arm_normal( current_covariance: np.ndarray, reward: float, llhood_sigma: float, - context: Optional[np.ndarray] = None, + context: np.ndarray, ) -> tuple[float, np.ndarray]: """ Update the mean and standard deviation of the Normal distribution. @@ -98,8 +98,7 @@ def _update_arm_normal( """ # Likelihood covariance matrix inverse llhood_covariance_inv = np.eye(len(current_mu)) / llhood_sigma**2 - if context is not None: - llhood_covariance_inv *= context.T @ context + llhood_covariance_inv *= context.T @ context # Prior covariance matrix inverse prior_covariance_inv = np.linalg.inv(current_covariance) @@ -109,9 +108,9 @@ def _update_arm_normal( # New mean llhood_term: Union[np.ndarray, float] = reward / llhood_sigma**2 - print("llhood_term", llhood_term) if context is not None: llhood_term = (context * llhood_term).squeeze() + new_mu = new_covariance @ ((prior_covariance_inv @ current_mu) + llhood_term) return new_mu.tolist(), new_covariance.tolist() @@ -138,6 +137,7 @@ def _update_arm_laplace( reward_likelihood : The likelihood function of the reward. 
prior_type : The prior type of the arm. """ + print(current_mu.shape, current_covariance.shape, reward.shape, context.shape) def objective(theta: np.ndarray) -> float: """ @@ -257,9 +257,8 @@ def update_arm( ] + [1.0] ) - context = ( - np.zeros((len(experiment.arms), 3)) if not context else np.array(context) - ) + context = np.zeros((len(rewards), 3)) if not context else np.array(context) + print(rewards, treatments) context[:, 0] = np.array(treatments) context[:, 1] = 1.0 - np.array(treatments) context[:, 2] = 1.0 @@ -302,7 +301,7 @@ def update_arm( arm.mu and arm.covariance ), "Arm must have mu and covariance parameters." if context is None: - context = np.ones_like(arm.mu) + context = np.ones((1, len(arm.mu))) # Normal likelihood if experiment.reward_type == RewardLikelihood.NORMAL: return _update_arm_normal( From 6ac9a311035fd8cb829723326f92d51a764905a0 Mon Sep 17 00:00:00 2001 From: poornimaramesh Date: Thu, 5 Jun 2025 18:14:38 +0300 Subject: [PATCH 52/74] fix messages test --- backend/app/messages/models.py | 2 +- ...392_fix_messages_foreign_key_constraint.py | 41 +++++++++++++++++++ backend/tests/test_messages.py | 13 +++--- 3 files changed, 50 insertions(+), 6 deletions(-) create mode 100644 backend/migrations/versions/45b9483ee392_fix_messages_foreign_key_constraint.py diff --git a/backend/app/messages/models.py b/backend/app/messages/models.py index 28ec3fb..61b557b 100644 --- a/backend/app/messages/models.py +++ b/backend/app/messages/models.py @@ -102,7 +102,7 @@ class EventMessageDB(MessageDB): nullable=False, ) experiment_id: Mapped[int] = mapped_column( - Integer, ForeignKey("experiments_base.experiment_id"), nullable=False + Integer, ForeignKey("experiments.experiment_id"), nullable=False ) __mapper_args__ = {"polymorphic_identity": "event"} diff --git a/backend/migrations/versions/45b9483ee392_fix_messages_foreign_key_constraint.py b/backend/migrations/versions/45b9483ee392_fix_messages_foreign_key_constraint.py new file mode 100644 index 
0000000..cdfcd4c --- /dev/null +++ b/backend/migrations/versions/45b9483ee392_fix_messages_foreign_key_constraint.py @@ -0,0 +1,41 @@ +"""fix messages foreign key constraint + +Revision ID: 45b9483ee392 +Revises: 6101ba814d91 +Create Date: 2025-06-05 18:10:33.744331 + +""" + +from typing import Sequence, Union + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "45b9483ee392" +down_revision: Union[str, None] = "6101ba814d91" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint( + op.f("event_messages_experiment_id_fkey"), "event_messages", type_="foreignkey" + ) + op.create_foreign_key( + None, "event_messages", "experiments", ["experiment_id"], ["experiment_id"] + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint(None, "event_messages", type_="foreignkey") + op.create_foreign_key( + op.f("event_messages_experiment_id_fkey"), + "event_messages", + "experiments_base", + ["experiment_id"], + ["experiment_id"], + ) + # ### end Alembic commands ### diff --git a/backend/tests/test_messages.py b/backend/tests/test_messages.py index 86396b4..253e4f2 100644 --- a/backend/tests/test_messages.py +++ b/backend/tests/test_messages.py @@ -9,19 +9,20 @@ base_mab_payload = { "name": "Test", - "description": "Test description", + "description": "Test description.", + "exp_type": "mab", "prior_type": "beta", "reward_type": "binary", "arms": [ { "name": "arm 1", - "description": "arm 1 description", + "description": "arm 1 description.", "alpha_init": 5, "beta_init": 1, }, { "name": "arm 2", - "description": "arm 2 description", + "description": "arm 2 description.", "alpha_init": 1, "beta_init": 4, }, @@ -34,6 +35,8 @@ "onPercentBetter": False, "percentBetterThreshold": 5, }, + 
"contexts": [], + "clients": [], } @@ -53,13 +56,13 @@ def admin_token(client: TestClient) -> str: @fixture def experiment_id(client: TestClient, admin_token: str) -> Generator[int, None, None]: response = client.post( - "/mab", + "/experiment", headers={"Authorization": f"Bearer {admin_token}"}, json=base_mab_payload, ) yield response.json()["experiment_id"] client.delete( - f"/mab/{response.json()['experiment_id']}", + f"/experiment/id/{response.json()['experiment_id']}", headers={"Authorization": f"Bearer {admin_token}"}, ) From 741245b93e5a92ace156b02dbe00aeb30606b7ad Mon Sep 17 00:00:00 2001 From: poornimaramesh Date: Thu, 5 Jun 2025 18:51:31 +0300 Subject: [PATCH 53/74] fix notifications and tests --- backend/app/experiments/sampling_utils.py | 2 - backend/app/models.py | 243 +--------------------- backend/app/schemas.py | 180 ---------------- backend/jobs/create_notifications.py | 18 +- backend/tests/test_notifications_job.py | 83 ++++---- 5 files changed, 52 insertions(+), 474 deletions(-) delete mode 100644 backend/app/schemas.py diff --git a/backend/app/experiments/sampling_utils.py b/backend/app/experiments/sampling_utils.py index ac9db72..7ae8eca 100644 --- a/backend/app/experiments/sampling_utils.py +++ b/backend/app/experiments/sampling_utils.py @@ -137,7 +137,6 @@ def _update_arm_laplace( reward_likelihood : The likelihood function of the reward. prior_type : The prior type of the arm. 
""" - print(current_mu.shape, current_covariance.shape, reward.shape, context.shape) def objective(theta: np.ndarray) -> float: """ @@ -258,7 +257,6 @@ def update_arm( + [1.0] ) context = np.zeros((len(rewards), 3)) if not context else np.array(context) - print(rewards, treatments) context[:, 0] = np.array(treatments) context[:, 1] = 1.0 - np.array(treatments) context[:, 2] = 1.0 diff --git a/backend/app/models.py b/backend/app/models.py index 77833b5..b8df8a0 100644 --- a/backend/app/models.py +++ b/backend/app/models.py @@ -1,21 +1,6 @@ -import uuid -from datetime import datetime -from typing import TYPE_CHECKING, Sequence +from typing import TYPE_CHECKING -from sqlalchemy import ( - Boolean, - DateTime, - Enum, - Float, - ForeignKey, - Integer, - String, - select, -) -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column - -from .schemas import AutoFailUnitType, EventType, Notifications, ObservationType +from sqlalchemy.orm import DeclarativeBase if TYPE_CHECKING: pass @@ -25,227 +10,3 @@ class Base(DeclarativeBase): """Base class for SQLAlchemy models""" pass - - -class ExperimentBaseDB(Base): - """ - Base model for experiments. 
- """ - - __tablename__ = "experiments_base" - - experiment_id: Mapped[int] = mapped_column( - Integer, primary_key=True, nullable=False - ) - name: Mapped[str] = mapped_column(String(length=150), nullable=False) - description: Mapped[str] = mapped_column(String(length=500), nullable=False) - sticky_assignment: Mapped[bool] = mapped_column( - Boolean, nullable=False, default=False - ) - auto_fail: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) - auto_fail_value: Mapped[int] = mapped_column(Integer, nullable=True) - auto_fail_unit: Mapped[AutoFailUnitType] = mapped_column( - Enum(AutoFailUnitType), nullable=True - ) - - user_id: Mapped[int] = mapped_column( - Integer, ForeignKey("users.user_id"), nullable=False - ) - workspace_id: Mapped[int] = mapped_column( - Integer, ForeignKey("workspace.workspace_id"), nullable=False - ) - is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) - exp_type: Mapped[str] = mapped_column(String(length=50), nullable=False) - prior_type: Mapped[str] = mapped_column(String(length=50), nullable=False) - reward_type: Mapped[str] = mapped_column(String(length=50), nullable=False) - - created_datetime_utc: Mapped[datetime] = mapped_column( - DateTime(timezone=True), nullable=False - ) - n_trials: Mapped[int] = mapped_column(Integer, nullable=False) - last_trial_datetime_utc: Mapped[datetime] = mapped_column( - DateTime(timezone=True), nullable=True - ) - - __mapper_args__ = { - "polymorphic_identity": "experiment", - "polymorphic_on": "exp_type", - } - - def __repr__(self) -> str: - """ - String representation of the model - """ - return f"" - - -class ArmBaseDB(Base): - """ - Base model for arms. 
- """ - - __tablename__ = "arms_base" - - arm_id: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False) - experiment_id: Mapped[int] = mapped_column( - Integer, ForeignKey("experiments_base.experiment_id"), nullable=False - ) - user_id: Mapped[int] = mapped_column( - Integer, ForeignKey("users.user_id"), nullable=False - ) - - name: Mapped[str] = mapped_column(String(length=150), nullable=False) - description: Mapped[str] = mapped_column(String(length=500), nullable=False) - arm_type: Mapped[str] = mapped_column(String(length=50), nullable=False) - n_outcomes: Mapped[int] = mapped_column(Integer, nullable=False, default=0) - - __mapper_args__ = { - "polymorphic_identity": "arm", - "polymorphic_on": "arm_type", - } - - -class DrawsBaseDB(Base): - """ - Base model for draws. - """ - - __tablename__ = "draws_base" - - draw_id: Mapped[str] = mapped_column( - String, primary_key=True, default=lambda x: str(uuid.uuid4()) - ) - - client_id: Mapped[str] = mapped_column(String, nullable=True) - - arm_id: Mapped[int] = mapped_column( - Integer, ForeignKey("arms_base.arm_id"), nullable=False - ) - experiment_id: Mapped[int] = mapped_column( - Integer, ForeignKey("experiments_base.experiment_id"), nullable=False - ) - user_id: Mapped[int] = mapped_column( - Integer, ForeignKey("users.user_id"), nullable=False - ) - - draw_datetime_utc: Mapped[datetime] = mapped_column( - DateTime(timezone=True), - nullable=False, - ) - - observed_datetime_utc: Mapped[datetime] = mapped_column( - DateTime(timezone=True), nullable=True - ) - - observation_type: Mapped[ObservationType] = mapped_column( - Enum(ObservationType), nullable=True - ) - - draw_type: Mapped[str] = mapped_column(String(length=50), nullable=False) - - reward: Mapped[float] = mapped_column(Float, nullable=True) - - __mapper_args__ = { - "polymorphic_identity": "draw", - "polymorphic_on": "draw_type", - } - - -class NotificationsDB(Base): - """ - Model for notifications. 
- Note: if you are updating this, you should also update models in - the background celery job - """ - - __tablename__ = "notifications_db" - - notification_id: Mapped[int] = mapped_column( - Integer, primary_key=True, nullable=False - ) - experiment_id: Mapped[int] = mapped_column( - Integer, ForeignKey("experiments_base.experiment_id"), nullable=False - ) - user_id: Mapped[int] = mapped_column( - Integer, ForeignKey("users.user_id"), nullable=False - ) - notification_type: Mapped[EventType] = mapped_column( - Enum(EventType), nullable=False - ) - notification_value: Mapped[int] = mapped_column(Integer, nullable=False) - is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) - - def to_dict(self) -> dict: - """ - Convert the model to a dictionary - """ - return { - "notification_id": self.notification_id, - "experiment_id": self.experiment_id, - "user_id": self.user_id, - "notification_type": self.notification_type, - "notification_value": self.notification_value, - "is_active": self.is_active, - } - - -async def save_notifications_to_db( - experiment_id: int, - user_id: int, - notifications: Notifications, - asession: AsyncSession, -) -> list[NotificationsDB]: - """ - Save notifications to the database - """ - notification_records = [] - - if notifications.onTrialCompletion: - notification_row = NotificationsDB( - experiment_id=experiment_id, - user_id=user_id, - notification_type=EventType.TRIALS_COMPLETED, - notification_value=notifications.numberOfTrials, - is_active=True, - ) - notification_records.append(notification_row) - - if notifications.onDaysElapsed: - notification_row = NotificationsDB( - experiment_id=experiment_id, - user_id=user_id, - notification_type=EventType.DAYS_ELAPSED, - notification_value=notifications.daysElapsed, - is_active=True, - ) - notification_records.append(notification_row) - - if notifications.onPercentBetter: - notification_row = NotificationsDB( - experiment_id=experiment_id, - user_id=user_id, - 
notification_type=EventType.PERCENTAGE_BETTER, - notification_value=notifications.percentBetterThreshold, - is_active=True, - ) - notification_records.append(notification_row) - - asession.add_all(notification_records) - await asession.commit() - - return notification_records - - -async def get_notifications_from_db( - experiment_id: int, user_id: int, asession: AsyncSession -) -> Sequence[NotificationsDB]: - """ - Get notifications from the database - """ - statement = ( - select(NotificationsDB) - .where(NotificationsDB.experiment_id == experiment_id) - .where(NotificationsDB.user_id == user_id) - ) - - return (await asession.execute(statement)).scalars().all() diff --git a/backend/app/schemas.py b/backend/app/schemas.py deleted file mode 100644 index 783c52e..0000000 --- a/backend/app/schemas.py +++ /dev/null @@ -1,180 +0,0 @@ -from enum import Enum, StrEnum -from typing import Any, Self - -import numpy as np -from pydantic import BaseModel, ConfigDict, model_validator -from pydantic.types import NonNegativeInt - - -class EventType(StrEnum): - """Types of events that can trigger a notification""" - - DAYS_ELAPSED = "days_elapsed" - TRIALS_COMPLETED = "trials_completed" - PERCENTAGE_BETTER = "percentage_better" - - -class ObservationType(StrEnum): - """Types of observations that can be made""" - - USER = "user" # Generated by the user - AUTO = "auto" # Generated by the system - - -class AutoFailUnitType(StrEnum): - """Types of units for auto fail""" - - DAYS = "days" - HOURS = "hours" - - -class Notifications(BaseModel): - """ - Pydantic model for a notifications. - """ - - onTrialCompletion: bool = False - numberOfTrials: NonNegativeInt | None - onDaysElapsed: bool = False - daysElapsed: NonNegativeInt | None - onPercentBetter: bool = False - percentBetterThreshold: NonNegativeInt | None - - @model_validator(mode="after") - def validate_has_assocatiated_value(self) -> Self: - """ - Validate that the required corresponding fields have been set. 
- """ - if self.onTrialCompletion and ( - not self.numberOfTrials or self.numberOfTrials == 0 - ): - raise ValueError( - "numberOfTrials is required when onTrialCompletion is True" - ) - if self.onDaysElapsed and (not self.daysElapsed or self.daysElapsed == 0): - raise ValueError("daysElapsed is required when onDaysElapsed is True") - if self.onPercentBetter and ( - not self.percentBetterThreshold or self.percentBetterThreshold == 0 - ): - raise ValueError( - "percentBetterThreshold is required when onPercentBetter is True" - ) - - return self - - -class NotificationsResponse(BaseModel): - """ - Pydantic model for a response for notifications - """ - - model_config = ConfigDict(from_attributes=True) - - notification_id: int - notification_type: EventType - notification_value: NonNegativeInt - is_active: bool - - -class Outcome(float, Enum): - """ - Enum for the outcome of a trial. - """ - - SUCCESS = 1 - FAILURE = 0 - - -class ArmPriors(StrEnum): - """ - Enum for the prior distribution of the arm. - """ - - BETA = "beta" - NORMAL = "normal" - - def __call__(self, theta: np.ndarray, **kwargs: Any) -> np.ndarray: - """ - Return the log pdf of the input param. - """ - if self == ArmPriors.BETA: - alpha = kwargs.get("alpha", np.ones_like(theta)) - beta = kwargs.get("beta", np.ones_like(theta)) - return (alpha - 1) * np.log(theta) + (beta - 1) * np.log(1 - theta) - - elif self == ArmPriors.NORMAL: - mu = kwargs.get("mu", np.zeros_like(theta)) - covariance = kwargs.get("covariance", np.diag(np.ones_like(theta))) - inv_cov = np.linalg.inv(covariance) - x = theta - mu - return -0.5 * x @ inv_cov @ x - - -class RewardLikelihood(StrEnum): - """ - Enum for the likelihood distribution of the reward. - """ - - BERNOULLI = "binary" - NORMAL = "real-valued" - - def __call__(self, reward: np.ndarray, probs: np.ndarray) -> np.ndarray: - """ - Calculate the log likelihood of the reward. - - Parameters - ---------- - reward : The reward. - probs : The probability of the reward. 
- """ - if self == RewardLikelihood.NORMAL: - return -0.5 * np.sum((reward - probs) ** 2) - elif self == RewardLikelihood.BERNOULLI: - return np.sum(reward * np.log(probs) + (1 - reward) * np.log(1 - probs)) - - -class ContextType(StrEnum): - """ - Enum for the type of context. - """ - - BINARY = "binary" - REAL_VALUED = "real-valued" - - -class ContextLinkFunctions(StrEnum): - """ - Enum for the link function of the arm params and context. - """ - - NONE = "none" - LOGISTIC = "logistic" - - def __call__(self, x: np.ndarray) -> np.ndarray: - """ - Apply the link function to the input param. - - Parameters - ---------- - x : The input param. - """ - if self == ContextLinkFunctions.NONE: - return x - elif self == ContextLinkFunctions.LOGISTIC: - return 1.0 / (1.0 + np.exp(-x)) - - -allowed_combos_mab = [ - (ArmPriors.BETA, RewardLikelihood.BERNOULLI), - (ArmPriors.NORMAL, RewardLikelihood.NORMAL), -] - -allowed_combos_cmab = [ - (ArmPriors.NORMAL, RewardLikelihood.BERNOULLI), - (ArmPriors.NORMAL, RewardLikelihood.NORMAL), -] - -allowed_combos_bayes_ab = [ - (ArmPriors.NORMAL, RewardLikelihood.BERNOULLI), - (ArmPriors.NORMAL, RewardLikelihood.NORMAL), -] diff --git a/backend/jobs/create_notifications.py b/backend/jobs/create_notifications.py index bb50508..f35adee 100644 --- a/backend/jobs/create_notifications.py +++ b/backend/jobs/create_notifications.py @@ -16,9 +16,9 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.database import get_async_session +from app.experiments.models import ExperimentDB, NotificationsDB +from app.experiments.schemas import EventType from app.messages.models import EventMessageDB -from app.models import ExperimentBaseDB, NotificationsDB -from app.schemas import EventType from app.utils import setup_logger logger = setup_logger(log_level=logging.INFO) @@ -34,10 +34,10 @@ async def check_days_elapsed( Check if the number of days elapsed since the experiment was created is greater than or equal to the milestone """ - 
experiments_stmt = select(ExperimentBaseDB).where( - ExperimentBaseDB.experiment_id == experiment_id + experiments_stmt = select(ExperimentDB).where( + ExperimentDB.experiment_id == experiment_id ) - experiment: ExperimentBaseDB | None = ( + experiment: ExperimentDB | None = ( (await asession.execute(experiments_stmt)).scalars().first() ) @@ -100,12 +100,8 @@ async def check_trials_completed( or equal to the milestone. """ # Fetch experiment - stmt = select(ExperimentBaseDB).where( - ExperimentBaseDB.experiment_id == experiment_id - ) - experiment: ExperimentBaseDB | None = ( - (await asession.execute(stmt)).scalars().first() - ) + stmt = select(ExperimentDB).where(ExperimentDB.experiment_id == experiment_id) + experiment: ExperimentDB | None = (await asession.execute(stmt)).scalars().first() if experiment: if experiment.n_trials >= milestone_trials: diff --git a/backend/tests/test_notifications_job.py b/backend/tests/test_notifications_job.py index 8b1ef80..911cbfc 100644 --- a/backend/tests/test_notifications_job.py +++ b/backend/tests/test_notifications_job.py @@ -12,21 +12,22 @@ from backend.jobs import create_notifications from backend.jobs.create_notifications import process_notifications -base_mab_payload = { +base_experiment_payload = { "name": "Test", - "description": "Test description", + "description": "Test description.", + "exp_type": "mab", "prior_type": "beta", "reward_type": "binary", "arms": [ { "name": "arm 1", - "description": "arm 1 description", + "description": "arm 1 description.", "alpha_init": 5, "beta_init": 1, }, { "name": "arm 2", - "description": "arm 2 description", + "description": "arm 2 description.", "alpha_init": 1, "beta_init": 4, }, @@ -39,6 +40,8 @@ "onPercentBetter": False, "percentBetterThreshold": 5, }, + "contexts": [], + "clients": [], } @@ -67,65 +70,65 @@ def admin_token(client: TestClient) -> str: class TestNotificationsJob: @fixture - def create_mabs_days_elapsed( + def create_experiments_days_elapsed( self, client: 
TestClient, admin_token: str, request: FixtureRequest ) -> Generator: - mabs = [] - n_mabs, days_elapsed = request.param + experiments = [] + n_experiments, days_elapsed = request.param - payload: dict = copy.deepcopy(base_mab_payload) + payload: dict = copy.deepcopy(base_experiment_payload) payload["notifications"]["onDaysElapsed"] = True payload["notifications"]["daysElapsed"] = days_elapsed - for _ in range(n_mabs): + for _ in range(n_experiments): response = client.post( - "/mab", + "/experiment", json=payload, headers={"Authorization": f"Bearer {admin_token}"}, ) - mabs.append(response.json()) - yield mabs - for mab in mabs: + experiments.append(response.json()) + yield experiments + for experiment in experiments: client.delete( - f"/mab/{mab['experiment_id']}", + f"/experiment/id/{experiment['experiment_id']}", headers={"Authorization": f"Bearer {admin_token}"}, ) @fixture - def create_mabs_trials_run( + def create_experiments_trials_run( self, client: TestClient, admin_token: str, request: FixtureRequest ) -> Generator: - mabs = [] - n_mabs, n_trials = request.param + experiments = [] + n_experiments, n_trials = request.param - payload: dict = copy.deepcopy(base_mab_payload) + payload: dict = copy.deepcopy(base_experiment_payload) payload["notifications"]["onTrialCompletion"] = True payload["notifications"]["numberOfTrials"] = n_trials - for _ in range(n_mabs): + for _ in range(n_experiments): response = client.post( - "/mab", + "/experiment", json=payload, headers={"Authorization": f"Bearer {admin_token}"}, ) - mabs.append(response.json()) - yield mabs - for mab in mabs: + experiments.append(response.json()) + yield experiments + for experiment in experiments: client.delete( - f"/mab/{mab['experiment_id']}", + f"/experiment/id/{experiment['experiment_id']}", headers={"Authorization": f"Bearer {admin_token}"}, ) @mark.parametrize( - "create_mabs_days_elapsed, days_elapsed", + "create_experiments_days_elapsed, days_elapsed", [((3, 4), 4), ((4, 62), 64), ((3, 
40), 40)], - indirect=["create_mabs_days_elapsed"], + indirect=["create_experiments_days_elapsed"], ) async def test_days_elapsed_notification( self, client: TestClient, admin_token: str, - create_mabs_days_elapsed: list[dict], + create_experiments_days_elapsed: list[dict], db_session: Session, days_elapsed: int, monkeypatch: MonkeyPatch, @@ -137,18 +140,18 @@ async def test_days_elapsed_notification( fake_datetime(days_elapsed), ) n_processed = await process_notifications(asession) - assert n_processed == len(create_mabs_days_elapsed) + assert n_processed == len(create_experiments_days_elapsed) @mark.parametrize( - "create_mabs_days_elapsed, days_elapsed", + "create_experiments_days_elapsed, days_elapsed", [((3, 4), 3), ((4, 62), 50), ((3, 40), 0)], - indirect=["create_mabs_days_elapsed"], + indirect=["create_experiments_days_elapsed"], ) async def test_days_elapsed_notification_not_sent( self, client: TestClient, admin_token: str, - create_mabs_days_elapsed: list[dict], + create_experiments_days_elapsed: list[dict], db_session: Session, days_elapsed: int, monkeypatch: MonkeyPatch, @@ -163,16 +166,16 @@ async def test_days_elapsed_notification_not_sent( assert n_processed == 0 @mark.parametrize( - "create_mabs_trials_run, n_trials", + "create_experiments_trials_run, n_trials", [((3, 4), 4), ((4, 62), 64), ((3, 40), 40)], - indirect=["create_mabs_trials_run"], + indirect=["create_experiments_trials_run"], ) async def test_trials_run_notification( self, client: TestClient, admin_token: str, n_trials: int, - create_mabs_trials_run: list[dict], + create_experiments_trials_run: list[dict], db_session: Session, asession: AsyncSession, workspace_api_key: str, @@ -180,11 +183,11 @@ async def test_trials_run_notification( n_processed = await process_notifications(asession) assert n_processed == 0 headers = {"Authorization": f"Bearer {workspace_api_key}"} - for mab in create_mabs_trials_run: + for experiment in create_experiments_trials_run: for i in range(n_trials): - 
draw_id = f"draw_{i}_{mab['experiment_id']}" - response = client.get( - f"/mab/{mab['experiment_id']}/draw", + draw_id = f"draw_{i}_{experiment['experiment_id']}" + response = client.put( + f"/experiment/{experiment['experiment_id']}/draw", params={"draw_id": draw_id}, headers=headers, ) @@ -192,10 +195,10 @@ async def test_trials_run_notification( assert response.json()["draw_id"] == draw_id response = client.put( - f"/mab/{mab['experiment_id']}/{draw_id}/1", + f"/experiment/{experiment['experiment_id']}/{draw_id}/1", headers=headers, ) assert response.status_code == 200 n_processed = await process_notifications(asession) await asyncio.sleep(0.1) - assert n_processed == len(create_mabs_trials_run) + assert n_processed == len(create_experiments_trials_run) From 8eea2db0bd266b01dc1b8c2ef06d2de51315028a Mon Sep 17 00:00:00 2001 From: poornimaramesh Date: Thu, 5 Jun 2025 21:34:08 +0300 Subject: [PATCH 54/74] debug autofail and fix corresponding tests --- backend/app/experiments/dependencies.py | 10 + backend/app/experiments/routers.py | 13 +- backend/jobs/auto_fail.py | 14 +- backend/tests/test_auto_fail.py | 284 +---- frontend/package-lock.json | 1496 +++++++++++++---------- frontend/tsconfig.json | 24 +- 6 files changed, 923 insertions(+), 918 deletions(-) diff --git a/backend/app/experiments/dependencies.py b/backend/app/experiments/dependencies.py index 249a989..75f6632 100644 --- a/backend/app/experiments/dependencies.py +++ b/backend/app/experiments/dependencies.py @@ -166,6 +166,8 @@ async def update_arm_based_on_outcome( rewards: list[float], contexts: Union[list[list[float]], None], treatments: Union[list[float], None], + observation_type: ObservationType, + asession: AsyncSession, ) -> ArmResponse: """ Update the arm parameters based on the outcome. 
@@ -191,6 +193,14 @@ async def update_arm_based_on_outcome( treatments=treatments, ) + await save_updated_data( + arm=experiment.arms[chosen_arm], + draw=draw, + reward=rewards[0], + observation_type=observation_type, + asession=asession, + ) + return ArmResponse.model_validate(arm) diff --git a/backend/app/experiments/routers.py b/backend/app/experiments/routers.py index 58431ad..feaba03 100644 --- a/backend/app/experiments/routers.py +++ b/backend/app/experiments/routers.py @@ -21,7 +21,6 @@ from .dependencies import ( experiments_db_to_schema, format_rewards_for_arm_update, - save_updated_data, update_arm_based_on_outcome, validate_experiment_and_draw, ) @@ -45,6 +44,7 @@ Experiment, ExperimentSample, ExperimentsEnum, + ObservationType, Outcome, ) @@ -431,17 +431,10 @@ async def update_experiment_arm( rewards=rewards_list, contexts=context_list, treatments=treatments_list, - ) - - observation_type = draw.observation_type - - await save_updated_data( - arm=experiment.arms[chosen_arm_index], - draw=draw, - reward=reward, - observation_type=observation_type, + observation_type=ObservationType.USER, asession=asession, ) + return ArmResponse.model_validate(experiment.arms[chosen_arm_index]) except Exception as e: raise HTTPException( diff --git a/backend/jobs/auto_fail.py b/backend/jobs/auto_fail.py index 2323a56..ebf8231 100644 --- a/backend/jobs/auto_fail.py +++ b/backend/jobs/auto_fail.py @@ -78,15 +78,17 @@ async def auto_fail_experiment(asession: AsyncSession) -> int: rewards_list, context_list, treatments_list = ( await format_rewards_for_arm_update( - experiment, draw.arm_id, 0.0, asession + experiment, draw.arm_id, 0.0, draw.context_val, asession ) ) await update_arm_based_on_outcome( - experiment, - draw, - rewards_list, - context_list, - treatments_list, + experiment=experiment, + draw=draw, + rewards=rewards_list, + contexts=context_list, + treatments=treatments_list, + observation_type=ObservationType.AUTO, + asession=asession, ) total_failed += 1 diff 
--git a/backend/tests/test_auto_fail.py b/backend/tests/test_auto_fail.py index 6a91def..e8ae0b5 100644 --- a/backend/tests/test_auto_fail.py +++ b/backend/tests/test_auto_fail.py @@ -6,14 +6,13 @@ from pytest import FixtureRequest, MonkeyPatch, fixture, mark from sqlalchemy.ext.asyncio import AsyncSession -from backend.app.bayes_ab import models as bayes_ab_models -from backend.app.contextual_mab import models as cmab_models -from backend.app.mab import models as mab_models -from backend.jobs.auto_fail import auto_fail_bayes_ab, auto_fail_cmab, auto_fail_mab +from backend.app.experiments import models +from backend.jobs.auto_fail import auto_fail_experiment -base_mab_payload = { +base_experiment_payload = { "name": "Test AUTO FAIL", "description": "Test AUTO FAIL description", + "exp_type": "mab", "prior_type": "beta", "reward_type": "binary", "auto_fail": True, @@ -41,84 +40,8 @@ "onPercentBetter": False, "percentBetterThreshold": 5, }, -} - -base_cmab_payload = { - "name": "Test", - "description": "Test description", - "prior_type": "normal", - "reward_type": "real-valued", - "auto_fail": True, - "auto_fail_value": 3, - "auto_fail_unit": "hours", - "arms": [ - { - "name": "arm 1", - "description": "arm 1 description", - "mu_init": 0, - "sigma_init": 1, - }, - { - "name": "arm 2", - "description": "arm 2 description", - "mu_init": 0, - "sigma_init": 1, - }, - ], - "contexts": [ - { - "name": "Context 1", - "description": "context 1 description", - "value_type": "binary", - }, - { - "name": "Context 2", - "description": "context 2 description", - "value_type": "real-valued", - }, - ], - "notifications": { - "onTrialCompletion": True, - "numberOfTrials": 2, - "onDaysElapsed": False, - "daysElapsed": 3, - "onPercentBetter": False, - "percentBetterThreshold": 5, - }, -} - -base_ab_payload = { - "name": "Test", - "description": "Test description", - "prior_type": "normal", - "reward_type": "real-valued", - "auto_fail": True, - "auto_fail_value": 3, - "auto_fail_unit": 
"hours", - "arms": [ - { - "name": "arm 1", - "description": "arm 1 description", - "mu_init": 0, - "sigma_init": 1, - "is_treatment_arm": True, - }, - { - "name": "arm 2", - "description": "arm 2 description", - "mu_init": 2, - "sigma_init": 2, - "is_treatment_arm": False, - }, - ], - "notifications": { - "onTrialCompletion": True, - "numberOfTrials": 2, - "onDaysElapsed": False, - "daysElapsed": 3, - "onPercentBetter": False, - "percentBetterThreshold": 5, - }, + "contexts": [], + "clients": [], } @@ -131,201 +54,47 @@ def now(cls, *arg: list) -> datetime: return mydatetime -class TestMABAutoFailJob: - @fixture - def create_mab_with_autofail( - self, - client: TestClient, - admin_token: str, - request: FixtureRequest, - ) -> Generator: - auto_fail_value, auto_fail_unit = request.param - mab_payload = copy.deepcopy(base_mab_payload) - mab_payload["auto_fail_value"] = auto_fail_value - mab_payload["auto_fail_unit"] = auto_fail_unit - - headers = {"Authorization": f"Bearer {admin_token}"} - response = client.post( - "/mab", - json=mab_payload, - headers=headers, - ) - assert response.status_code == 200 - mab = response.json() - yield mab - headers = {"Authorization": f"Bearer {admin_token}"} - client.delete(f"/mab/{mab['experiment_id']}", headers=headers) - - @mark.parametrize( - "create_mab_with_autofail, fail_value, fail_unit, n_observed", - [ - ((12, "hours"), 12, "hours", 2), - ((10, "days"), 10, "days", 3), - ((3, "hours"), 3, "hours", 0), - ((5, "days"), 5, "days", 0), - ], - indirect=["create_mab_with_autofail"], - ) - async def test_auto_fail_job( - self, - client: TestClient, - admin_token: str, - monkeypatch: MonkeyPatch, - create_mab_with_autofail: dict, - fail_value: int, - fail_unit: Literal["days", "hours"], - n_observed: int, - asession: AsyncSession, - workspace_api_key: str, - ) -> None: - draws = [] - headers = {"Authorization": f"Bearer {workspace_api_key}"} - for i in range(1, 15): - monkeypatch.setattr( - mab_models, - "datetime", - 
fake_datetime( - days=i if fail_unit == "days" else 0, - hours=i if fail_unit == "hours" else 0, - ), - ) - response = client.get( - f"/mab/{create_mab_with_autofail['experiment_id']}/draw", - headers=headers, - ) - assert response.status_code == 200 - draws.append(response.json()["draw_id"]) - - if i >= (15 - n_observed): - response = client.put( - f"/mab/{create_mab_with_autofail['experiment_id']}/{draws[-1]}/1", - headers=headers, - ) - assert response.status_code == 200 - - n_failed = await auto_fail_mab(asession=asession) - - assert n_failed == (15 - fail_value - n_observed) - - -class TestBayesABAutoFailJob: - @fixture - def create_bayes_ab_with_autofail( - self, - client: TestClient, - admin_token: str, - request: FixtureRequest, - ) -> Generator: - auto_fail_value, auto_fail_unit = request.param - ab_payload = copy.deepcopy(base_ab_payload) - ab_payload["auto_fail_value"] = auto_fail_value - ab_payload["auto_fail_unit"] = auto_fail_unit - - headers = {"Authorization": f"Bearer {admin_token}"} - response = client.post( - "/bayes_ab", - json=ab_payload, - headers=headers, - ) - assert response.status_code == 200 - ab = response.json() - yield ab - headers = {"Authorization": f"Bearer {admin_token}"} - client.delete(f"/bayes_ab/{ab['experiment_id']}", headers=headers) - - @mark.parametrize( - "create_bayes_ab_with_autofail, fail_value, fail_unit, n_observed", - [ - ((12, "hours"), 12, "hours", 2), - ((10, "days"), 10, "days", 3), - ((3, "hours"), 3, "hours", 0), - ((5, "days"), 5, "days", 0), - ], - indirect=["create_bayes_ab_with_autofail"], - ) - async def test_auto_fail_job( - self, - client: TestClient, - admin_token: str, - monkeypatch: MonkeyPatch, - create_bayes_ab_with_autofail: dict, - fail_value: int, - fail_unit: Literal["days", "hours"], - n_observed: int, - asession: AsyncSession, - workspace_api_key: str, - ) -> None: - draws = [] - headers = {"Authorization": f"Bearer {workspace_api_key}"} - for i in range(1, 15): - monkeypatch.setattr( - 
bayes_ab_models, - "datetime", - fake_datetime( - days=i if fail_unit == "days" else 0, - hours=i if fail_unit == "hours" else 0, - ), - ) - response = client.get( - f"/bayes_ab/{create_bayes_ab_with_autofail['experiment_id']}/draw", - headers=headers, - ) - assert response.status_code == 200 - draws.append(response.json()["draw_id"]) - - if i >= (15 - n_observed): - response = client.put( - f"/bayes_ab/{create_bayes_ab_with_autofail['experiment_id']}/{draws[-1]}/1", - headers=headers, - ) - assert response.status_code == 200 - - n_failed = await auto_fail_bayes_ab(asession=asession) - - assert n_failed == (15 - fail_value - n_observed) - - -class TestCMABAutoFailJob: +class TestExperimentAutoFailJob: @fixture - def create_cmab_with_autofail( + def create_experiment_with_autofail( self, client: TestClient, admin_token: str, request: FixtureRequest, ) -> Generator: auto_fail_value, auto_fail_unit = request.param - cmab_payload = copy.deepcopy(base_cmab_payload) - cmab_payload["auto_fail_value"] = auto_fail_value - cmab_payload["auto_fail_unit"] = auto_fail_unit + experiment_payload = copy.deepcopy(base_experiment_payload) + experiment_payload["auto_fail_value"] = auto_fail_value + experiment_payload["auto_fail_unit"] = auto_fail_unit headers = {"Authorization": f"Bearer {admin_token}"} response = client.post( - "/contextual_mab", - json=cmab_payload, + "/experiment", + json=experiment_payload, headers=headers, ) assert response.status_code == 200 - cmab = response.json() - yield cmab + experiment = response.json() + yield experiment headers = {"Authorization": f"Bearer {admin_token}"} - client.delete(f"/contextual_mab/{cmab['experiment_id']}", headers=headers) + client.delete(f"/experiment/id/{experiment['experiment_id']}", headers=headers) @mark.parametrize( - "create_cmab_with_autofail, fail_value, fail_unit, n_observed", + "create_experiment_with_autofail, fail_value, fail_unit, n_observed", [ ((12, "hours"), 12, "hours", 2), ((10, "days"), 10, "days", 3), ((3, 
"hours"), 3, "hours", 0), ((5, "days"), 5, "days", 0), ], - indirect=["create_cmab_with_autofail"], + indirect=["create_experiment_with_autofail"], ) async def test_auto_fail_job( self, client: TestClient, admin_token: str, monkeypatch: MonkeyPatch, - create_cmab_with_autofail: dict, + create_experiment_with_autofail: dict, fail_value: int, fail_unit: Literal["days", "hours"], n_observed: int, @@ -336,19 +105,15 @@ async def test_auto_fail_job( headers = {"Authorization": f"Bearer {workspace_api_key}"} for i in range(1, 15): monkeypatch.setattr( - cmab_models, + models, "datetime", fake_datetime( days=i if fail_unit == "days" else 0, hours=i if fail_unit == "hours" else 0, ), ) - response = client.post( - f"/contextual_mab/{create_cmab_with_autofail['experiment_id']}/draw", - json=[ - {"context_id": 1, "context_value": 0}, - {"context_id": 2, "context_value": 0}, - ], + response = client.put( + f"/experiment/{create_experiment_with_autofail['experiment_id']}/draw", headers=headers, ) assert response.status_code == 200 @@ -356,11 +121,12 @@ async def test_auto_fail_job( if i >= (15 - n_observed): response = client.put( - f"/contextual_mab/{create_cmab_with_autofail['experiment_id']}/{draws[-1]}/1", + f"/experiment/{create_experiment_with_autofail['experiment_id']}/{draws[-1]}/1", headers=headers, ) + print(response.json()) assert response.status_code == 200 - n_failed = await auto_fail_cmab(asession=asession) + n_failed = await auto_fail_experiment(asession=asession) assert n_failed == (15 - fail_value - n_observed) diff --git a/frontend/package-lock.json b/frontend/package-lock.json index bd3f82d..d4c756a 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -83,14 +83,25 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/@babel/runtime": { - "version": "7.27.0", - "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.27.0.tgz", - "integrity": 
"sha512-VtPOkrdPHZsKc/clNqyi9WUA8TINkZ4cGk63UUE3u4pmB2k+ZMQRDuIOagv8UVd6j7k0T3+RRIb7beKTebNbcw==", - "license": "MIT", + "node_modules/@ampproject/remapping": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/@ampproject/remapping/-/remapping-2.3.0.tgz", + "integrity": "sha512-30iZtAPgz+LTIYoeivqYo853f02jBYSd5uGnGpkFV0M3xOt9aN73erkgYAmZU43x4VfqcnLxW9Kpg3R5LC4YYw==", + "dev": true, + "license": "Apache-2.0", "dependencies": { - "regenerator-runtime": "^0.14.0" + "@jridgewell/gen-mapping": "^0.3.5", + "@jridgewell/trace-mapping": "^0.3.24" }, + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/@babel/runtime": { + "version": "7.27.6", + "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.27.6.tgz", + "integrity": "sha512-vbavdySgbTTrmFE+EsiqUTzlOr5bzlnJtUv9PynGCAKvfQqjIXbvFdumPM/GxMDfyuGMJaJAU6TO4zc1Jf1i8Q==", + "license": "MIT", "engines": { "node": ">=6.9.0" } @@ -130,9 +141,9 @@ } }, "node_modules/@eslint-community/eslint-utils": { - "version": "4.6.1", - "resolved": "https://registry.npmjs.org/@eslint-community/eslint-utils/-/eslint-utils-4.6.1.tgz", - "integrity": "sha512-KTsJMmobmbrFLe3LDh0PC2FXpcSYJt/MLjlkh/9LEnmKYLSYmT/0EW9JWANjeoemiuZrmogti0tW5Ch+qNUYDw==", + "version": "4.7.0", + "resolved": "https://registry.npmjs.org/@eslint-community/eslint-utils/-/eslint-utils-4.7.0.tgz", + "integrity": "sha512-dyybb3AcajC7uha6CvhdVRJqaKyn7w2YKqKyAN37NKYgZT36w+iRb0Dymmc5qEJ549c/S31cMMSFd75bteCpCw==", "dev": true, "license": "MIT", "dependencies": { @@ -193,28 +204,28 @@ } }, "node_modules/@floating-ui/core": { - "version": "1.6.9", - "resolved": "https://registry.npmjs.org/@floating-ui/core/-/core-1.6.9.tgz", - "integrity": "sha512-uMXCuQ3BItDUbAMhIXw7UPXRfAlOAvZzdK9BWpE60MCn+Svt3aLn9jsPTi/WNGlRUu2uI0v5S7JiIUsbsvh3fw==", + "version": "1.7.1", + "resolved": "https://registry.npmjs.org/@floating-ui/core/-/core-1.7.1.tgz", + "integrity": "sha512-azI0DrjMMfIug/ExbBaeDVJXcY0a7EPvPjb2xAJPa4HeimBX+Z18HK8QQR3jb6356SnDDdxx+hinMLcJEDdOjw==", 
"license": "MIT", "dependencies": { "@floating-ui/utils": "^0.2.9" } }, "node_modules/@floating-ui/dom": { - "version": "1.6.13", - "resolved": "https://registry.npmjs.org/@floating-ui/dom/-/dom-1.6.13.tgz", - "integrity": "sha512-umqzocjDgNRGTuO7Q8CU32dkHkECqI8ZdMZ5Swb6QAM0t5rnlrN3lGo1hdpscRd3WS8T6DKYK4ephgIH9iRh3w==", + "version": "1.7.1", + "resolved": "https://registry.npmjs.org/@floating-ui/dom/-/dom-1.7.1.tgz", + "integrity": "sha512-cwsmW/zyw5ltYTUeeYJ60CnQuPqmGwuGVhG9w0PRaRKkAyi38BT5CKrpIbb+jtahSwUl04cWzSx9ZOIxeS6RsQ==", "license": "MIT", "dependencies": { - "@floating-ui/core": "^1.6.0", + "@floating-ui/core": "^1.7.1", "@floating-ui/utils": "^0.2.9" } }, "node_modules/@floating-ui/react-dom": { - "version": "2.1.2", - "resolved": "https://registry.npmjs.org/@floating-ui/react-dom/-/react-dom-2.1.2.tgz", - "integrity": "sha512-06okr5cgPzMNBy+Ycse2A6udMi4bqwW/zgBF/rwjcNqWkyr82Mcg8b0vjX8OJpZFy/FKjJmw6wV7t44kK6kW7A==", + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/@floating-ui/react-dom/-/react-dom-2.1.3.tgz", + "integrity": "sha512-huMBfiU9UnQ2oBwIhgzyIiSpVgvlDstU8CX0AF+wS+KzmYMs0J2a3GwuFHV1Lz+jlrQGeC1fF+Nv0QoumyV0bA==", "license": "MIT", "dependencies": { "@floating-ui/dom": "^1.0.0" @@ -324,23 +335,89 @@ "url": "https://github.com/chalk/strip-ansi?sponsor=1" } }, + "node_modules/@isaacs/fs-minipass": { + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/@isaacs/fs-minipass/-/fs-minipass-4.0.1.tgz", + "integrity": "sha512-wgm9Ehl2jpeqP3zw/7mo3kRHFp5MEDhqAdwy1fTGkHAwnkGOVsgpvQhL8B5n1qlb01jV3n/bI0ZfZp5lWA1k4w==", + "dev": true, + "license": "ISC", + "dependencies": { + "minipass": "^7.0.4" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@jridgewell/gen-mapping": { + "version": "0.3.8", + "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.8.tgz", + "integrity": "sha512-imAbBGkb+ebQyxKgzv5Hu2nmROxoDOXHh80evxdoXNOrvAnVx7zimzc1Oo5h9RlfV4vPXaE2iM5pOFbvOCClWA==", + "dev": true, + 
"license": "MIT", + "dependencies": { + "@jridgewell/set-array": "^1.2.1", + "@jridgewell/sourcemap-codec": "^1.4.10", + "@jridgewell/trace-mapping": "^0.3.24" + }, + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/@jridgewell/resolve-uri": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/@jridgewell/resolve-uri/-/resolve-uri-3.1.2.tgz", + "integrity": "sha512-bRISgCIjP20/tbWSPWMEi54QVPRZExkuD9lJL+UIxUKtwVJA8wW1Trb1jMs1RFXo1CBTNZ/5hpC9QvmKWdopKw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/@jridgewell/set-array": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/@jridgewell/set-array/-/set-array-1.2.1.tgz", + "integrity": "sha512-R8gLRTZeyp03ymzP/6Lil/28tGeGEzhx1q2k703KGWRAI1VdvPIXdG70VJc2pAMw3NA6JKL5hhFu1sJX0Mnn/A==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/@jridgewell/sourcemap-codec": { + "version": "1.5.0", + "resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.5.0.tgz", + "integrity": "sha512-gv3ZRaISU3fjPAgNsriBRqGWQL6quFx04YMPW/zD8XMLsU32mhCCbfbO6KZFLjvYpCZ8zyDEgqsgf+PwPaM7GQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/@jridgewell/trace-mapping": { + "version": "0.3.25", + "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.25.tgz", + "integrity": "sha512-vNk6aEwybGtawWmy/PzwnGDOjCkLWSD2wqvjGGAgOAwCGWySYXfYoxt00IJkTF+8Lb57DwOb3Aa0o9CApepiYQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@jridgewell/resolve-uri": "^3.1.0", + "@jridgewell/sourcemap-codec": "^1.4.14" + } + }, "node_modules/@napi-rs/wasm-runtime": { - "version": "0.2.9", - "resolved": "https://registry.npmjs.org/@napi-rs/wasm-runtime/-/wasm-runtime-0.2.9.tgz", - "integrity": "sha512-OKRBiajrrxB9ATokgEQoG87Z25c67pCpYcCwmXYX8PBftC9pBfN18gnm/fh1wurSLEKIAt+QRFLFCQISrb66Jg==", + "version": "0.2.10", + "resolved": 
"https://registry.npmjs.org/@napi-rs/wasm-runtime/-/wasm-runtime-0.2.10.tgz", + "integrity": "sha512-bCsCyeZEwVErsGmyPNSzwfwFn4OdxBj0mmv6hOFucB/k81Ojdu68RbZdxYsRQUPc9l6SU5F/cG+bXgWs3oUgsQ==", "dev": true, "license": "MIT", "optional": true, "dependencies": { - "@emnapi/core": "^1.4.0", - "@emnapi/runtime": "^1.4.0", + "@emnapi/core": "^1.4.3", + "@emnapi/runtime": "^1.4.3", "@tybys/wasm-util": "^0.9.0" } }, "node_modules/@next/env": { - "version": "14.2.28", - "resolved": "https://registry.npmjs.org/@next/env/-/env-14.2.28.tgz", - "integrity": "sha512-PAmWhJfJQlP+kxZwCjrVd9QnR5x0R3u0mTXTiZDgSd4h5LdXmjxCCWbN9kq6hkZBOax8Rm3xDW5HagWyJuT37g==", + "version": "14.2.29", + "resolved": "https://registry.npmjs.org/@next/env/-/env-14.2.29.tgz", + "integrity": "sha512-UzgLR2eBfhKIQt0aJ7PWH7XRPYw7SXz0Fpzdl5THjUnvxy4kfBk9OU4RNPNiETewEEtaBcExNFNn1QWH8wQTjg==", "license": "MIT" }, "node_modules/@next/eslint-plugin-next": { @@ -354,9 +431,9 @@ } }, "node_modules/@next/swc-darwin-arm64": { - "version": "14.2.28", - "resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-14.2.28.tgz", - "integrity": "sha512-kzGChl9setxYWpk3H6fTZXXPFFjg7urptLq5o5ZgYezCrqlemKttwMT5iFyx/p1e/JeglTwDFRtb923gTJ3R1w==", + "version": "14.2.29", + "resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-14.2.29.tgz", + "integrity": "sha512-wWtrAaxCVMejxPHFb1SK/PVV1WDIrXGs9ki0C/kUM8ubKHQm+3hU9MouUywCw8Wbhj3pewfHT2wjunLEr/TaLA==", "cpu": [ "arm64" ], @@ -370,9 +447,9 @@ } }, "node_modules/@next/swc-darwin-x64": { - "version": "14.2.28", - "resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-14.2.28.tgz", - "integrity": "sha512-z6FXYHDJlFOzVEOiiJ/4NG8aLCeayZdcRSMjPDysW297Up6r22xw6Ea9AOwQqbNsth8JNgIK8EkWz2IDwaLQcw==", + "version": "14.2.29", + "resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-14.2.29.tgz", + "integrity": 
"sha512-7Z/jk+6EVBj4pNLw/JQrvZVrAh9Bv8q81zCFSfvTMZ51WySyEHWVpwCEaJY910LyBftv2F37kuDPQm0w9CEXyg==", "cpu": [ "x64" ], @@ -386,9 +463,9 @@ } }, "node_modules/@next/swc-linux-arm64-gnu": { - "version": "14.2.28", - "resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-14.2.28.tgz", - "integrity": "sha512-9ARHLEQXhAilNJ7rgQX8xs9aH3yJSj888ssSjJLeldiZKR4D7N08MfMqljk77fAwZsWwsrp8ohHsMvurvv9liQ==", + "version": "14.2.29", + "resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-14.2.29.tgz", + "integrity": "sha512-o6hrz5xRBwi+G7JFTHc+RUsXo2lVXEfwh4/qsuWBMQq6aut+0w98WEnoNwAwt7hkEqegzvazf81dNiwo7KjITw==", "cpu": [ "arm64" ], @@ -402,9 +479,9 @@ } }, "node_modules/@next/swc-linux-arm64-musl": { - "version": "14.2.28", - "resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-14.2.28.tgz", - "integrity": "sha512-p6gvatI1nX41KCizEe6JkF0FS/cEEF0u23vKDpl+WhPe/fCTBeGkEBh7iW2cUM0rvquPVwPWdiUR6Ebr/kQWxQ==", + "version": "14.2.29", + "resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-14.2.29.tgz", + "integrity": "sha512-9i+JEHBOVgqxQ92HHRFlSW1EQXqa/89IVjtHgOqsShCcB/ZBjTtkWGi+SGCJaYyWkr/lzu51NTMCfKuBf7ULNw==", "cpu": [ "arm64" ], @@ -418,9 +495,9 @@ } }, "node_modules/@next/swc-linux-x64-gnu": { - "version": "14.2.28", - "resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-14.2.28.tgz", - "integrity": "sha512-nsiSnz2wO6GwMAX2o0iucONlVL7dNgKUqt/mDTATGO2NY59EO/ZKnKEr80BJFhuA5UC1KZOMblJHWZoqIJddpA==", + "version": "14.2.29", + "resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-14.2.29.tgz", + "integrity": "sha512-B7JtMbkUwHijrGBOhgSQu2ncbCYq9E7PZ7MX58kxheiEOwdkM+jGx0cBb+rN5AeqF96JypEppK6i/bEL9T13lA==", "cpu": [ "x64" ], @@ -434,9 +511,9 @@ } }, "node_modules/@next/swc-linux-x64-musl": { - "version": "14.2.28", - "resolved": 
"https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-14.2.28.tgz", - "integrity": "sha512-+IuGQKoI3abrXFqx7GtlvNOpeExUH1mTIqCrh1LGFf8DnlUcTmOOCApEnPJUSLrSbzOdsF2ho2KhnQoO0I1RDw==", + "version": "14.2.29", + "resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-14.2.29.tgz", + "integrity": "sha512-yCcZo1OrO3aQ38B5zctqKU1Z3klOohIxug6qdiKO3Q3qNye/1n6XIs01YJ+Uf+TdpZQ0fNrOQI2HrTLF3Zprnw==", "cpu": [ "x64" ], @@ -450,9 +527,9 @@ } }, "node_modules/@next/swc-win32-arm64-msvc": { - "version": "14.2.28", - "resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-14.2.28.tgz", - "integrity": "sha512-l61WZ3nevt4BAnGksUVFKy2uJP5DPz2E0Ma/Oklvo3sGj9sw3q7vBWONFRgz+ICiHpW5mV+mBrkB3XEubMrKaA==", + "version": "14.2.29", + "resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-14.2.29.tgz", + "integrity": "sha512-WnrfeOEtTVidI9Z6jDLy+gxrpDcEJtZva54LYC0bSKQqmyuHzl0ego+v0F/v2aXq0am67BRqo/ybmmt45Tzo4A==", "cpu": [ "arm64" ], @@ -466,9 +543,9 @@ } }, "node_modules/@next/swc-win32-ia32-msvc": { - "version": "14.2.28", - "resolved": "https://registry.npmjs.org/@next/swc-win32-ia32-msvc/-/swc-win32-ia32-msvc-14.2.28.tgz", - "integrity": "sha512-+Kcp1T3jHZnJ9v9VTJ/yf1t/xmtFAc/Sge4v7mVc1z+NYfYzisi8kJ9AsY8itbgq+WgEwMtOpiLLJsUy2qnXZw==", + "version": "14.2.29", + "resolved": "https://registry.npmjs.org/@next/swc-win32-ia32-msvc/-/swc-win32-ia32-msvc-14.2.29.tgz", + "integrity": "sha512-vkcriFROT4wsTdSeIzbxaZjTNTFKjSYmLd8q/GVH3Dn8JmYjUKOuKXHK8n+lovW/kdcpIvydO5GtN+It2CvKWA==", "cpu": [ "ia32" ], @@ -482,9 +559,9 @@ } }, "node_modules/@next/swc-win32-x64-msvc": { - "version": "14.2.28", - "resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-14.2.28.tgz", - "integrity": "sha512-1gCmpvyhz7DkB1srRItJTnmR2UwQPAUXXIg9r0/56g3O8etGmwlX68skKXJOp9EejW3hhv7nSQUJ2raFiz4MoA==", + "version": "14.2.29", + "resolved": 
"https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-14.2.29.tgz", + "integrity": "sha512-iPPwUEKnVs7pwR0EBLJlwxLD7TTHWS/AoVZx1l9ZQzfQciqaFEr5AlYzA2uB6Fyby1IF18t4PL0nTpB+k4Tzlw==", "cpu": [ "x64" ], @@ -575,12 +652,12 @@ "license": "MIT" }, "node_modules/@radix-ui/react-accessible-icon": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/@radix-ui/react-accessible-icon/-/react-accessible-icon-1.1.4.tgz", - "integrity": "sha512-J8pIt7l32A9fGIn86vwccQzik5MgIOTtceeTxi6EiiFYwWHLxsTHwiOW4pI5sQhQJWd3MOEkumFBIHwIU038Cw==", + "version": "1.1.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-accessible-icon/-/react-accessible-icon-1.1.7.tgz", + "integrity": "sha512-XM+E4WXl0OqUJFovy6GjmxxFyx9opfCAIUku4dlKRd5YEPqt4kALOkQOp0Of6reHuUkJuiPBEc5k0o4z4lTC8A==", "license": "MIT", "dependencies": { - "@radix-ui/react-visually-hidden": "1.2.0" + "@radix-ui/react-visually-hidden": "1.2.3" }, "peerDependencies": { "@types/react": "*", @@ -598,19 +675,19 @@ } }, "node_modules/@radix-ui/react-accordion": { - "version": "1.2.8", - "resolved": "https://registry.npmjs.org/@radix-ui/react-accordion/-/react-accordion-1.2.8.tgz", - "integrity": "sha512-c7OKBvO36PfQIUGIjj1Wko0hH937pYFU2tR5zbIJDUsmTzHoZVHHt4bmb7OOJbzTaWJtVELKWojBHa7OcnUHmQ==", + "version": "1.2.11", + "resolved": "https://registry.npmjs.org/@radix-ui/react-accordion/-/react-accordion-1.2.11.tgz", + "integrity": "sha512-l3W5D54emV2ues7jjeG1xcyN7S3jnK3zE2zHqgn0CmMsy9lNJwmgcrmaxS+7ipw15FAivzKNzH3d5EcGoFKw0A==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", - "@radix-ui/react-collapsible": "1.1.8", - "@radix-ui/react-collection": "1.1.4", + "@radix-ui/react-collapsible": "1.1.11", + "@radix-ui/react-collection": "1.1.7", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-direction": "1.1.1", "@radix-ui/react-id": "1.1.1", - "@radix-ui/react-primitive": "2.1.0", + "@radix-ui/react-primitive": "2.1.3", 
"@radix-ui/react-use-controllable-state": "1.2.2" }, "peerDependencies": { @@ -629,17 +706,17 @@ } }, "node_modules/@radix-ui/react-alert-dialog": { - "version": "1.1.11", - "resolved": "https://registry.npmjs.org/@radix-ui/react-alert-dialog/-/react-alert-dialog-1.1.11.tgz", - "integrity": "sha512-4KfkwrFnAw3Y5Jeoq6G+JYSKW0JfIS3uDdFC/79Jw9AsMayZMizSSMxk1gkrolYXsa/WzbbDfOA7/D8N5D+l1g==", + "version": "1.1.14", + "resolved": "https://registry.npmjs.org/@radix-ui/react-alert-dialog/-/react-alert-dialog-1.1.14.tgz", + "integrity": "sha512-IOZfZ3nPvN6lXpJTBCunFQPRSvK8MDgSc1FB85xnIpUKOw9en0dJj8JmCAxV7BiZdtYlUpmrQjoTFkVYtdoWzQ==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", - "@radix-ui/react-dialog": "1.1.11", - "@radix-ui/react-primitive": "2.1.0", - "@radix-ui/react-slot": "1.2.0" + "@radix-ui/react-dialog": "1.1.14", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-slot": "1.2.3" }, "peerDependencies": { "@types/react": "*", @@ -657,12 +734,12 @@ } }, "node_modules/@radix-ui/react-arrow": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/@radix-ui/react-arrow/-/react-arrow-1.1.4.tgz", - "integrity": "sha512-qz+fxrqgNxG0dYew5l7qR3c7wdgRu1XVUHGnGYX7rg5HM4p9SWaRmJwfgR3J0SgyUKayLmzQIun+N6rWRgiRKw==", + "version": "1.1.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-arrow/-/react-arrow-1.1.7.tgz", + "integrity": "sha512-F+M1tLhO+mlQaOWspE8Wstg+z6PwxwRd8oQ8IXceWz92kfAmalTRf0EjrouQeo7QssEPfCn05B4Ihs1K9WQ/7w==", "license": "MIT", "dependencies": { - "@radix-ui/react-primitive": "2.1.0" + "@radix-ui/react-primitive": "2.1.3" }, "peerDependencies": { "@types/react": "*", @@ -680,12 +757,12 @@ } }, "node_modules/@radix-ui/react-aspect-ratio": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/@radix-ui/react-aspect-ratio/-/react-aspect-ratio-1.1.4.tgz", - "integrity": 
"sha512-ie2mUDtM38LBqVU+Xn+GIY44tWM5yVbT5uXO+th85WZxUUsgEdWNNZWecqqGzkQ4Af+Fq1mYT6TyQ/uUf5gfcw==", + "version": "1.1.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-aspect-ratio/-/react-aspect-ratio-1.1.7.tgz", + "integrity": "sha512-Yq6lvO9HQyPwev1onK1daHCHqXVLzPhSVjmsNjCa2Zcxy2f7uJD2itDtxknv6FzAKCwD1qQkeVDmX/cev13n/g==", "license": "MIT", "dependencies": { - "@radix-ui/react-primitive": "2.1.0" + "@radix-ui/react-primitive": "2.1.3" }, "peerDependencies": { "@types/react": "*", @@ -703,13 +780,13 @@ } }, "node_modules/@radix-ui/react-avatar": { - "version": "1.1.7", - "resolved": "https://registry.npmjs.org/@radix-ui/react-avatar/-/react-avatar-1.1.7.tgz", - "integrity": "sha512-V7ODUt4mUoJTe3VUxZw6nfURxaPALVqmDQh501YmaQsk3D8AZQrOPRnfKn4H7JGDLBc0KqLhT94H79nV88ppNg==", + "version": "1.1.10", + "resolved": "https://registry.npmjs.org/@radix-ui/react-avatar/-/react-avatar-1.1.10.tgz", + "integrity": "sha512-V8piFfWapM5OmNCXTzVQY+E1rDa53zY+MQ4Y7356v4fFz6vqCyUtIz2rUD44ZEdwg78/jKmMJHj07+C/Z/rcog==", "license": "MIT", "dependencies": { "@radix-ui/react-context": "1.1.2", - "@radix-ui/react-primitive": "2.1.0", + "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-callback-ref": "1.1.1", "@radix-ui/react-use-is-hydrated": "0.1.0", "@radix-ui/react-use-layout-effect": "1.1.1" @@ -730,16 +807,16 @@ } }, "node_modules/@radix-ui/react-checkbox": { - "version": "1.2.3", - "resolved": "https://registry.npmjs.org/@radix-ui/react-checkbox/-/react-checkbox-1.2.3.tgz", - "integrity": "sha512-pHVzDYsnaDmBlAuwim45y3soIN8H4R7KbkSVirGhXO+R/kO2OLCe0eucUEbddaTcdMHHdzcIGHtZSMSQlA+apw==", + "version": "1.3.2", + "resolved": "https://registry.npmjs.org/@radix-ui/react-checkbox/-/react-checkbox-1.3.2.tgz", + "integrity": "sha512-yd+dI56KZqawxKZrJ31eENUwqc1QSqg4OZ15rybGjF2ZNwMO+wCyHzAVLRp9qoYJf7kYy0YpZ2b0JCzJ42HZpA==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", 
"@radix-ui/react-presence": "1.1.4", - "@radix-ui/react-primitive": "2.1.0", + "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-controllable-state": "1.2.2", "@radix-ui/react-use-previous": "1.1.1", "@radix-ui/react-use-size": "1.1.1" @@ -760,9 +837,9 @@ } }, "node_modules/@radix-ui/react-collapsible": { - "version": "1.1.8", - "resolved": "https://registry.npmjs.org/@radix-ui/react-collapsible/-/react-collapsible-1.1.8.tgz", - "integrity": "sha512-hxEsLvK9WxIAPyxdDRULL4hcaSjMZCfP7fHB0Z1uUnDoDBat1Zh46hwYfa69DeZAbJrPckjf0AGAtEZyvDyJbw==", + "version": "1.1.11", + "resolved": "https://registry.npmjs.org/@radix-ui/react-collapsible/-/react-collapsible-1.1.11.tgz", + "integrity": "sha512-2qrRsVGSCYasSz1RFOorXwl0H7g7J1frQtgpQgYrt+MOidtPAINHn9CPovQXb83r8ahapdx3Tu0fa/pdFFSdPg==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", @@ -770,7 +847,7 @@ "@radix-ui/react-context": "1.1.2", "@radix-ui/react-id": "1.1.1", "@radix-ui/react-presence": "1.1.4", - "@radix-ui/react-primitive": "2.1.0", + "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-controllable-state": "1.2.2", "@radix-ui/react-use-layout-effect": "1.1.1" }, @@ -790,15 +867,15 @@ } }, "node_modules/@radix-ui/react-collection": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/@radix-ui/react-collection/-/react-collection-1.1.4.tgz", - "integrity": "sha512-cv4vSf7HttqXilDnAnvINd53OTl1/bjUYVZrkFnA7nwmY9Ob2POUy0WY0sfqBAe1s5FyKsyceQlqiEGPYNTadg==", + "version": "1.1.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-collection/-/react-collection-1.1.7.tgz", + "integrity": "sha512-Fh9rGN0MoI4ZFUNyfFVNU4y9LUz93u9/0K+yLgA2bwRojxM8JU1DyvvMBabnZPBgMWREAJvU2jjVzq+LrFUglw==", "license": "MIT", "dependencies": { "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", - "@radix-ui/react-primitive": "2.1.0", - "@radix-ui/react-slot": "1.2.0" + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-slot": "1.2.3" }, "peerDependencies": { 
"@types/react": "*", @@ -846,15 +923,15 @@ } }, "node_modules/@radix-ui/react-context-menu": { - "version": "2.2.12", - "resolved": "https://registry.npmjs.org/@radix-ui/react-context-menu/-/react-context-menu-2.2.12.tgz", - "integrity": "sha512-5UFKuTMX8F2/KjHvyqu9IYT8bEtDSCJwwIx1PghBo4jh9S6jJVsceq9xIjqsOVcxsynGwV5eaqPE3n/Cu+DrSA==", + "version": "2.2.15", + "resolved": "https://registry.npmjs.org/@radix-ui/react-context-menu/-/react-context-menu-2.2.15.tgz", + "integrity": "sha512-UsQUMjcYTsBjTSXw0P3GO0werEQvUY2plgRQuKoCTtkNr45q1DiL51j4m7gxhABzZ0BadoXNsIbg7F3KwiUBbw==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-context": "1.1.2", - "@radix-ui/react-menu": "2.1.12", - "@radix-ui/react-primitive": "2.1.0", + "@radix-ui/react-menu": "2.1.15", + "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-callback-ref": "1.1.1", "@radix-ui/react-use-controllable-state": "1.2.2" }, @@ -874,22 +951,22 @@ } }, "node_modules/@radix-ui/react-dialog": { - "version": "1.1.11", - "resolved": "https://registry.npmjs.org/@radix-ui/react-dialog/-/react-dialog-1.1.11.tgz", - "integrity": "sha512-yI7S1ipkP5/+99qhSI6nthfo/tR6bL6Zgxi/+1UO6qPa6UeM6nlafWcQ65vB4rU2XjgjMfMhI3k9Y5MztA62VQ==", + "version": "1.1.14", + "resolved": "https://registry.npmjs.org/@radix-ui/react-dialog/-/react-dialog-1.1.14.tgz", + "integrity": "sha512-+CpweKjqpzTmwRwcYECQcNYbI8V9VSQt0SNFKeEBLgfucbsLssU6Ppq7wUdNXEGb573bMjFhVjKVll8rmV6zMw==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", - "@radix-ui/react-dismissable-layer": "1.1.7", + "@radix-ui/react-dismissable-layer": "1.1.10", "@radix-ui/react-focus-guards": "1.1.2", - "@radix-ui/react-focus-scope": "1.1.4", + "@radix-ui/react-focus-scope": "1.1.7", "@radix-ui/react-id": "1.1.1", - "@radix-ui/react-portal": "1.1.6", + "@radix-ui/react-portal": "1.1.9", "@radix-ui/react-presence": "1.1.4", - 
"@radix-ui/react-primitive": "2.1.0", - "@radix-ui/react-slot": "1.2.0", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-slot": "1.2.3", "@radix-ui/react-use-controllable-state": "1.2.2", "aria-hidden": "^1.2.4", "react-remove-scroll": "^2.6.3" @@ -925,14 +1002,14 @@ } }, "node_modules/@radix-ui/react-dismissable-layer": { - "version": "1.1.7", - "resolved": "https://registry.npmjs.org/@radix-ui/react-dismissable-layer/-/react-dismissable-layer-1.1.7.tgz", - "integrity": "sha512-j5+WBUdhccJsmH5/H0K6RncjDtoALSEr6jbkaZu+bjw6hOPOhHycr6vEUujl+HBK8kjUfWcoCJXxP6e4lUlMZw==", + "version": "1.1.10", + "resolved": "https://registry.npmjs.org/@radix-ui/react-dismissable-layer/-/react-dismissable-layer-1.1.10.tgz", + "integrity": "sha512-IM1zzRV4W3HtVgftdQiiOmA0AdJlCtMLe00FXaHwgt3rAnNsIyDqshvkIW3hj/iu5hu8ERP7KIYki6NkqDxAwQ==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-compose-refs": "1.1.2", - "@radix-ui/react-primitive": "2.1.0", + "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-callback-ref": "1.1.1", "@radix-ui/react-use-escape-keydown": "1.1.1" }, @@ -952,17 +1029,17 @@ } }, "node_modules/@radix-ui/react-dropdown-menu": { - "version": "2.1.12", - "resolved": "https://registry.npmjs.org/@radix-ui/react-dropdown-menu/-/react-dropdown-menu-2.1.12.tgz", - "integrity": "sha512-VJoMs+BWWE7YhzEQyVwvF9n22Eiyr83HotCVrMQzla/OwRovXCgah7AcaEr4hMNj4gJxSdtIbcHGvmJXOoJVHA==", + "version": "2.1.15", + "resolved": "https://registry.npmjs.org/@radix-ui/react-dropdown-menu/-/react-dropdown-menu-2.1.15.tgz", + "integrity": "sha512-mIBnOjgwo9AH3FyKaSWoSu/dYj6VdhJ7frEPiGTeXCdUFHjl9h3mFh2wwhEtINOmYXWhdpf1rY2minFsmaNgVQ==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-id": "1.1.1", - "@radix-ui/react-menu": "2.1.12", - "@radix-ui/react-primitive": "2.1.0", + "@radix-ui/react-menu": "2.1.15", + 
"@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-controllable-state": "1.2.2" }, "peerDependencies": { @@ -996,13 +1073,13 @@ } }, "node_modules/@radix-ui/react-focus-scope": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/@radix-ui/react-focus-scope/-/react-focus-scope-1.1.4.tgz", - "integrity": "sha512-r2annK27lIW5w9Ho5NyQgqs0MmgZSTIKXWpVCJaLC1q2kZrZkcqnmHkCHMEmv8XLvsLlurKMPT+kbKkRkm/xVA==", + "version": "1.1.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-focus-scope/-/react-focus-scope-1.1.7.tgz", + "integrity": "sha512-t2ODlkXBQyn7jkl6TNaw/MtVEVvIGelJDCG41Okq/KwUsJBwQ4XVZsHAVUkK4mBv3ewiAS3PGuUWuY2BoK4ZUw==", "license": "MIT", "dependencies": { "@radix-ui/react-compose-refs": "1.1.2", - "@radix-ui/react-primitive": "2.1.0", + "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-callback-ref": "1.1.1" }, "peerDependencies": { @@ -1021,17 +1098,17 @@ } }, "node_modules/@radix-ui/react-form": { - "version": "0.1.4", - "resolved": "https://registry.npmjs.org/@radix-ui/react-form/-/react-form-0.1.4.tgz", - "integrity": "sha512-97Q7Hb0///sMF2X8XvyVx3Aub7WG/ybIofoDVUo8utG/z/6TBzWGjgai7ZjECXYLbKip88t9/ibyQJvYe5k6SA==", + "version": "0.1.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-form/-/react-form-0.1.7.tgz", + "integrity": "sha512-IXLKFnaYvFg/KkeV5QfOX7tRnwHXp127koOFUjLWMTrRv5Rny3DQcAtIFFeA/Cli4HHM8DuJCXAUsgnFVJndlw==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-id": "1.1.1", - "@radix-ui/react-label": "2.1.4", - "@radix-ui/react-primitive": "2.1.0" + "@radix-ui/react-label": "2.1.7", + "@radix-ui/react-primitive": "2.1.3" }, "peerDependencies": { "@types/react": "*", @@ -1049,19 +1126,19 @@ } }, "node_modules/@radix-ui/react-hover-card": { - "version": "1.1.11", - "resolved": "https://registry.npmjs.org/@radix-ui/react-hover-card/-/react-hover-card-1.1.11.tgz", - "integrity": 
"sha512-q9h9grUpGZKR3MNhtVCLVnPGmx1YnzBgGR+O40mhSNGsUnkR+LChVH8c7FB0mkS+oudhd8KAkZGTJPJCjdAPIg==", + "version": "1.1.14", + "resolved": "https://registry.npmjs.org/@radix-ui/react-hover-card/-/react-hover-card-1.1.14.tgz", + "integrity": "sha512-CPYZ24Mhirm+g6D8jArmLzjYu4Eyg3TTUHswR26QgzXBHBe64BO/RHOJKzmF/Dxb4y4f9PKyJdwm/O/AhNkb+Q==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", - "@radix-ui/react-dismissable-layer": "1.1.7", - "@radix-ui/react-popper": "1.2.4", - "@radix-ui/react-portal": "1.1.6", + "@radix-ui/react-dismissable-layer": "1.1.10", + "@radix-ui/react-popper": "1.2.7", + "@radix-ui/react-portal": "1.1.9", "@radix-ui/react-presence": "1.1.4", - "@radix-ui/react-primitive": "2.1.0", + "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-controllable-state": "1.2.2" }, "peerDependencies": { @@ -1107,12 +1184,12 @@ } }, "node_modules/@radix-ui/react-label": { - "version": "2.1.4", - "resolved": "https://registry.npmjs.org/@radix-ui/react-label/-/react-label-2.1.4.tgz", - "integrity": "sha512-wy3dqizZnZVV4ja0FNnUhIWNwWdoldXrneEyUcVtLYDAt8ovGS4ridtMAOGgXBBIfggL4BOveVWsjXDORdGEQg==", + "version": "2.1.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-label/-/react-label-2.1.7.tgz", + "integrity": "sha512-YT1GqPSL8kJn20djelMX7/cTRp/Y9w5IZHvfxQTVHrOqa2yMl7i/UfMqKRU5V7mEyKTrUVgJXhNQPVCG8PBLoQ==", "license": "MIT", "dependencies": { - "@radix-ui/react-primitive": "2.1.0" + "@radix-ui/react-primitive": "2.1.3" }, "peerDependencies": { "@types/react": "*", @@ -1130,26 +1207,26 @@ } }, "node_modules/@radix-ui/react-menu": { - "version": "2.1.12", - "resolved": "https://registry.npmjs.org/@radix-ui/react-menu/-/react-menu-2.1.12.tgz", - "integrity": "sha512-+qYq6LfbiGo97Zz9fioX83HCiIYYFNs8zAsVCMQrIakoNYylIzWuoD/anAD3UzvvR6cnswmfRFJFq/zYYq/k7Q==", + "version": "2.1.15", + "resolved": 
"https://registry.npmjs.org/@radix-ui/react-menu/-/react-menu-2.1.15.tgz", + "integrity": "sha512-tVlmA3Vb9n8SZSd+YSbuFR66l87Wiy4du+YE+0hzKQEANA+7cWKH1WgqcEX4pXqxUFQKrWQGHdvEfw00TjFiew==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", - "@radix-ui/react-collection": "1.1.4", + "@radix-ui/react-collection": "1.1.7", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-direction": "1.1.1", - "@radix-ui/react-dismissable-layer": "1.1.7", + "@radix-ui/react-dismissable-layer": "1.1.10", "@radix-ui/react-focus-guards": "1.1.2", - "@radix-ui/react-focus-scope": "1.1.4", + "@radix-ui/react-focus-scope": "1.1.7", "@radix-ui/react-id": "1.1.1", - "@radix-ui/react-popper": "1.2.4", - "@radix-ui/react-portal": "1.1.6", + "@radix-ui/react-popper": "1.2.7", + "@radix-ui/react-portal": "1.1.9", "@radix-ui/react-presence": "1.1.4", - "@radix-ui/react-primitive": "2.1.0", - "@radix-ui/react-roving-focus": "1.1.7", - "@radix-ui/react-slot": "1.2.0", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-roving-focus": "1.1.10", + "@radix-ui/react-slot": "1.2.3", "@radix-ui/react-use-callback-ref": "1.1.1", "aria-hidden": "^1.2.4", "react-remove-scroll": "^2.6.3" @@ -1170,20 +1247,20 @@ } }, "node_modules/@radix-ui/react-menubar": { - "version": "1.1.12", - "resolved": "https://registry.npmjs.org/@radix-ui/react-menubar/-/react-menubar-1.1.12.tgz", - "integrity": "sha512-bM2vT5nxRqJH/d1vFQ9jLsW4qR70yFQw2ZD1TUPWUNskDsV0eYeMbbNJqxNjGMOVogEkOJaHtu11kzYdTJvVJg==", + "version": "1.1.15", + "resolved": "https://registry.npmjs.org/@radix-ui/react-menubar/-/react-menubar-1.1.15.tgz", + "integrity": "sha512-Z71C7LGD+YDYo3TV81paUs8f3Zbmkvg6VLRQpKYfzioOE6n7fOhA3ApK/V/2Odolxjoc4ENk8AYCjohCNayd5A==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", - "@radix-ui/react-collection": "1.1.4", + "@radix-ui/react-collection": "1.1.7", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", 
"@radix-ui/react-direction": "1.1.1", "@radix-ui/react-id": "1.1.1", - "@radix-ui/react-menu": "2.1.12", - "@radix-ui/react-primitive": "2.1.0", - "@radix-ui/react-roving-focus": "1.1.7", + "@radix-ui/react-menu": "2.1.15", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-roving-focus": "1.1.10", "@radix-ui/react-use-controllable-state": "1.2.2" }, "peerDependencies": { @@ -1202,25 +1279,25 @@ } }, "node_modules/@radix-ui/react-navigation-menu": { - "version": "1.2.10", - "resolved": "https://registry.npmjs.org/@radix-ui/react-navigation-menu/-/react-navigation-menu-1.2.10.tgz", - "integrity": "sha512-kGDqMVPj2SRB1vJmXN/jnhC66REAXNyDmDRubbbmJ+360zSIJUDmWGMKIJOf72PHMwPENrbtJVb3CMAUJDjEIA==", + "version": "1.2.13", + "resolved": "https://registry.npmjs.org/@radix-ui/react-navigation-menu/-/react-navigation-menu-1.2.13.tgz", + "integrity": "sha512-WG8wWfDiJlSF5hELjwfjSGOXcBR/ZMhBFCGYe8vERpC39CQYZeq1PQ2kaYHdye3V95d06H89KGMsVCIE4LWo3g==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", - "@radix-ui/react-collection": "1.1.4", + "@radix-ui/react-collection": "1.1.7", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-direction": "1.1.1", - "@radix-ui/react-dismissable-layer": "1.1.7", + "@radix-ui/react-dismissable-layer": "1.1.10", "@radix-ui/react-id": "1.1.1", "@radix-ui/react-presence": "1.1.4", - "@radix-ui/react-primitive": "2.1.0", + "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-callback-ref": "1.1.1", "@radix-ui/react-use-controllable-state": "1.2.2", "@radix-ui/react-use-layout-effect": "1.1.1", "@radix-ui/react-use-previous": "1.1.1", - "@radix-ui/react-visually-hidden": "1.2.0" + "@radix-ui/react-visually-hidden": "1.2.3" }, "peerDependencies": { "@types/react": "*", @@ -1238,19 +1315,19 @@ } }, "node_modules/@radix-ui/react-one-time-password-field": { - "version": "0.1.4", - "resolved": 
"https://registry.npmjs.org/@radix-ui/react-one-time-password-field/-/react-one-time-password-field-0.1.4.tgz", - "integrity": "sha512-CygYLHY8kO1De5iAZBn7gQbIoRNVGYx1paIyqbmwlxP6DF7sF1LLW3chXo/qxc4IWUQnsgAhfl9u6IoLXTndqQ==", + "version": "0.1.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-one-time-password-field/-/react-one-time-password-field-0.1.7.tgz", + "integrity": "sha512-w1vm7AGI8tNXVovOK7TYQHrAGpRF7qQL+ENpT1a743De5Zmay2RbWGKAiYDKIyIuqptns+znCKwNztE2xl1n0Q==", "license": "MIT", "dependencies": { "@radix-ui/number": "1.1.1", "@radix-ui/primitive": "1.1.2", - "@radix-ui/react-collection": "1.1.4", + "@radix-ui/react-collection": "1.1.7", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-direction": "1.1.1", - "@radix-ui/react-primitive": "2.1.0", - "@radix-ui/react-roving-focus": "1.1.7", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-roving-focus": "1.1.10", "@radix-ui/react-use-controllable-state": "1.2.2", "@radix-ui/react-use-effect-event": "0.0.2", "@radix-ui/react-use-is-hydrated": "0.1.0", @@ -1271,24 +1348,54 @@ } } }, + "node_modules/@radix-ui/react-password-toggle-field": { + "version": "0.1.2", + "resolved": "https://registry.npmjs.org/@radix-ui/react-password-toggle-field/-/react-password-toggle-field-0.1.2.tgz", + "integrity": "sha512-F90uYnlBsLPU1UbSLciLsWQmk8+hdWa6SFw4GXaIdNWxFxI5ITKVdAG64f+Twaa9ic6xE7pqxPyUmodrGjT4pQ==", + "license": "MIT", + "dependencies": { + "@radix-ui/primitive": "1.1.2", + "@radix-ui/react-compose-refs": "1.1.2", + "@radix-ui/react-context": "1.1.2", + "@radix-ui/react-id": "1.1.1", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-use-controllable-state": "1.2.2", + "@radix-ui/react-use-effect-event": "0.0.2", + "@radix-ui/react-use-is-hydrated": "0.1.0" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 
|| ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, "node_modules/@radix-ui/react-popover": { - "version": "1.1.11", - "resolved": "https://registry.npmjs.org/@radix-ui/react-popover/-/react-popover-1.1.11.tgz", - "integrity": "sha512-yFMfZkVA5G3GJnBgb2PxrrcLKm1ZLWXrbYVgdyTl//0TYEIHS9LJbnyz7WWcZ0qCq7hIlJZpRtxeSeIG5T5oJw==", + "version": "1.1.14", + "resolved": "https://registry.npmjs.org/@radix-ui/react-popover/-/react-popover-1.1.14.tgz", + "integrity": "sha512-ODz16+1iIbGUfFEfKx2HTPKizg2MN39uIOV8MXeHnmdd3i/N9Wt7vU46wbHsqA0xoaQyXVcs0KIlBdOA2Y95bw==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", - "@radix-ui/react-dismissable-layer": "1.1.7", + "@radix-ui/react-dismissable-layer": "1.1.10", "@radix-ui/react-focus-guards": "1.1.2", - "@radix-ui/react-focus-scope": "1.1.4", + "@radix-ui/react-focus-scope": "1.1.7", "@radix-ui/react-id": "1.1.1", - "@radix-ui/react-popper": "1.2.4", - "@radix-ui/react-portal": "1.1.6", + "@radix-ui/react-popper": "1.2.7", + "@radix-ui/react-portal": "1.1.9", "@radix-ui/react-presence": "1.1.4", - "@radix-ui/react-primitive": "2.1.0", - "@radix-ui/react-slot": "1.2.0", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-slot": "1.2.3", "@radix-ui/react-use-controllable-state": "1.2.2", "aria-hidden": "^1.2.4", "react-remove-scroll": "^2.6.3" @@ -1309,16 +1416,16 @@ } }, "node_modules/@radix-ui/react-popper": { - "version": "1.2.4", - "resolved": "https://registry.npmjs.org/@radix-ui/react-popper/-/react-popper-1.2.4.tgz", - "integrity": "sha512-3p2Rgm/a1cK0r/UVkx5F/K9v/EplfjAeIFCGOPYPO4lZ0jtg4iSQXt/YGTSLWaf4x7NG6Z4+uKFcylcTZjeqDA==", + "version": "1.2.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-popper/-/react-popper-1.2.7.tgz", + "integrity": 
"sha512-IUFAccz1JyKcf/RjB552PlWwxjeCJB8/4KxT7EhBHOJM+mN7LdW+B3kacJXILm32xawcMMjb2i0cIZpo+f9kiQ==", "license": "MIT", "dependencies": { "@floating-ui/react-dom": "^2.0.0", - "@radix-ui/react-arrow": "1.1.4", + "@radix-ui/react-arrow": "1.1.7", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", - "@radix-ui/react-primitive": "2.1.0", + "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-callback-ref": "1.1.1", "@radix-ui/react-use-layout-effect": "1.1.1", "@radix-ui/react-use-rect": "1.1.1", @@ -1341,12 +1448,12 @@ } }, "node_modules/@radix-ui/react-portal": { - "version": "1.1.6", - "resolved": "https://registry.npmjs.org/@radix-ui/react-portal/-/react-portal-1.1.6.tgz", - "integrity": "sha512-XmsIl2z1n/TsYFLIdYam2rmFwf9OC/Sh2avkbmVMDuBZIe7hSpM0cYnWPAo7nHOVx8zTuwDZGByfcqLdnzp3Vw==", + "version": "1.1.9", + "resolved": "https://registry.npmjs.org/@radix-ui/react-portal/-/react-portal-1.1.9.tgz", + "integrity": "sha512-bpIxvq03if6UNwXZ+HTK71JLh4APvnXntDc6XOX8UVq4XQOVl7lwok0AvIl+b8zgCw3fSaVTZMpAPPagXbKmHQ==", "license": "MIT", "dependencies": { - "@radix-ui/react-primitive": "2.1.0", + "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-layout-effect": "1.1.1" }, "peerDependencies": { @@ -1389,12 +1496,12 @@ } }, "node_modules/@radix-ui/react-primitive": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/@radix-ui/react-primitive/-/react-primitive-2.1.0.tgz", - "integrity": "sha512-/J/FhLdK0zVcILOwt5g+dH4KnkonCtkVJsa2G6JmvbbtZfBEI1gMsO3QMjseL4F/SwfAMt1Vc/0XKYKq+xJ1sw==", + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-primitive/-/react-primitive-2.1.3.tgz", + "integrity": "sha512-m9gTwRkhy2lvCPe6QJp4d3G1TYEUHn/FzJUtq9MjH46an1wJU+GdoGC5VLof8RX8Ft/DlpshApkhswDLZzHIcQ==", "license": "MIT", "dependencies": { - "@radix-ui/react-slot": "1.2.0" + "@radix-ui/react-slot": "1.2.3" }, "peerDependencies": { "@types/react": "*", @@ -1412,13 +1519,13 @@ } }, 
"node_modules/@radix-ui/react-progress": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/@radix-ui/react-progress/-/react-progress-1.1.4.tgz", - "integrity": "sha512-8rl9w7lJdcVPor47Dhws9mUHRHLE+8JEgyJRdNWCpGPa6HIlr3eh+Yn9gyx1CnCLbw5naHsI2gaO9dBWO50vzw==", + "version": "1.1.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-progress/-/react-progress-1.1.7.tgz", + "integrity": "sha512-vPdg/tF6YC/ynuBIJlk1mm7Le0VgW6ub6J2UWnTQ7/D23KXcPI1qy+0vBkgKgd38RCMJavBXpB83HPNFMTb0Fg==", "license": "MIT", "dependencies": { "@radix-ui/react-context": "1.1.2", - "@radix-ui/react-primitive": "2.1.0" + "@radix-ui/react-primitive": "2.1.3" }, "peerDependencies": { "@types/react": "*", @@ -1436,9 +1543,9 @@ } }, "node_modules/@radix-ui/react-radio-group": { - "version": "1.3.4", - "resolved": "https://registry.npmjs.org/@radix-ui/react-radio-group/-/react-radio-group-1.3.4.tgz", - "integrity": "sha512-N4J9QFdW5zcJNxxY/zwTXBN4Uc5VEuRM7ZLjNfnWoKmNvgrPtNNw4P8zY532O3qL6aPkaNO+gY9y6bfzmH4U1g==", + "version": "1.3.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-radio-group/-/react-radio-group-1.3.7.tgz", + "integrity": "sha512-9w5XhD0KPOrm92OTTE0SysH3sYzHsSTHNvZgUBo/VZ80VdYyB5RneDbc0dKpURS24IxkoFRu/hI0i4XyfFwY6g==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", @@ -1446,8 +1553,8 @@ "@radix-ui/react-context": "1.1.2", "@radix-ui/react-direction": "1.1.1", "@radix-ui/react-presence": "1.1.4", - "@radix-ui/react-primitive": "2.1.0", - "@radix-ui/react-roving-focus": "1.1.7", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-roving-focus": "1.1.10", "@radix-ui/react-use-controllable-state": "1.2.2", "@radix-ui/react-use-previous": "1.1.1", "@radix-ui/react-use-size": "1.1.1" @@ -1468,18 +1575,18 @@ } }, "node_modules/@radix-ui/react-roving-focus": { - "version": "1.1.7", - "resolved": "https://registry.npmjs.org/@radix-ui/react-roving-focus/-/react-roving-focus-1.1.7.tgz", - "integrity": 
"sha512-C6oAg451/fQT3EGbWHbCQjYTtbyjNO1uzQgMzwyivcHT3GKNEmu1q3UuREhN+HzHAVtv3ivMVK08QlC+PkYw9Q==", + "version": "1.1.10", + "resolved": "https://registry.npmjs.org/@radix-ui/react-roving-focus/-/react-roving-focus-1.1.10.tgz", + "integrity": "sha512-dT9aOXUen9JSsxnMPv/0VqySQf5eDQ6LCk5Sw28kamz8wSOW2bJdlX2Bg5VUIIcV+6XlHpWTIuTPCf/UNIyq8Q==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", - "@radix-ui/react-collection": "1.1.4", + "@radix-ui/react-collection": "1.1.7", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-direction": "1.1.1", "@radix-ui/react-id": "1.1.1", - "@radix-ui/react-primitive": "2.1.0", + "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-callback-ref": "1.1.1", "@radix-ui/react-use-controllable-state": "1.2.2" }, @@ -1499,9 +1606,9 @@ } }, "node_modules/@radix-ui/react-scroll-area": { - "version": "1.2.6", - "resolved": "https://registry.npmjs.org/@radix-ui/react-scroll-area/-/react-scroll-area-1.2.6.tgz", - "integrity": "sha512-lj8OMlpPERXrQIHlEQdlXHJoRT52AMpBrgyPYylOhXYq5e/glsEdtOc/kCQlsTdtgN5U0iDbrrolDadvektJGQ==", + "version": "1.2.9", + "resolved": "https://registry.npmjs.org/@radix-ui/react-scroll-area/-/react-scroll-area-1.2.9.tgz", + "integrity": "sha512-YSjEfBXnhUELsO2VzjdtYYD4CfQjvao+lhhrX5XsHD7/cyUNzljF1FHEbgTPN7LH2MClfwRMIsYlqTYpKTTe2A==", "license": "MIT", "dependencies": { "@radix-ui/number": "1.1.1", @@ -1510,7 +1617,7 @@ "@radix-ui/react-context": "1.1.2", "@radix-ui/react-direction": "1.1.1", "@radix-ui/react-presence": "1.1.4", - "@radix-ui/react-primitive": "2.1.0", + "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-callback-ref": "1.1.1", "@radix-ui/react-use-layout-effect": "1.1.1" }, @@ -1530,30 +1637,30 @@ } }, "node_modules/@radix-ui/react-select": { - "version": "2.2.2", - "resolved": "https://registry.npmjs.org/@radix-ui/react-select/-/react-select-2.2.2.tgz", - "integrity": 
"sha512-HjkVHtBkuq+r3zUAZ/CvNWUGKPfuicGDbgtZgiQuFmNcV5F+Tgy24ep2nsAW2nFgvhGPJVqeBZa6KyVN0EyrBA==", + "version": "2.2.5", + "resolved": "https://registry.npmjs.org/@radix-ui/react-select/-/react-select-2.2.5.tgz", + "integrity": "sha512-HnMTdXEVuuyzx63ME0ut4+sEMYW6oouHWNGUZc7ddvUWIcfCva/AMoqEW/3wnEllriMWBa0RHspCYnfCWJQYmA==", "license": "MIT", "dependencies": { "@radix-ui/number": "1.1.1", "@radix-ui/primitive": "1.1.2", - "@radix-ui/react-collection": "1.1.4", + "@radix-ui/react-collection": "1.1.7", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-direction": "1.1.1", - "@radix-ui/react-dismissable-layer": "1.1.7", + "@radix-ui/react-dismissable-layer": "1.1.10", "@radix-ui/react-focus-guards": "1.1.2", - "@radix-ui/react-focus-scope": "1.1.4", + "@radix-ui/react-focus-scope": "1.1.7", "@radix-ui/react-id": "1.1.1", - "@radix-ui/react-popper": "1.2.4", - "@radix-ui/react-portal": "1.1.6", - "@radix-ui/react-primitive": "2.1.0", - "@radix-ui/react-slot": "1.2.0", + "@radix-ui/react-popper": "1.2.7", + "@radix-ui/react-portal": "1.1.9", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-slot": "1.2.3", "@radix-ui/react-use-callback-ref": "1.1.1", "@radix-ui/react-use-controllable-state": "1.2.2", "@radix-ui/react-use-layout-effect": "1.1.1", "@radix-ui/react-use-previous": "1.1.1", - "@radix-ui/react-visually-hidden": "1.2.0", + "@radix-ui/react-visually-hidden": "1.2.3", "aria-hidden": "^1.2.4", "react-remove-scroll": "^2.6.3" }, @@ -1573,12 +1680,12 @@ } }, "node_modules/@radix-ui/react-separator": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/@radix-ui/react-separator/-/react-separator-1.1.4.tgz", - "integrity": "sha512-2fTm6PSiUm8YPq9W0E4reYuv01EE3aFSzt8edBiXqPHshF8N9+Kymt/k0/R+F3dkY5lQyB/zPtrP82phskLi7w==", + "version": "1.1.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-separator/-/react-separator-1.1.7.tgz", + "integrity": 
"sha512-0HEb8R9E8A+jZjvmFCy/J4xhbXy3TV+9XSnGJ3KvTtjlIUy/YQ/p6UYZvi7YbeoeXdyU9+Y3scizK6hkY37baA==", "license": "MIT", "dependencies": { - "@radix-ui/react-primitive": "2.1.0" + "@radix-ui/react-primitive": "2.1.3" }, "peerDependencies": { "@types/react": "*", @@ -1596,18 +1703,18 @@ } }, "node_modules/@radix-ui/react-slider": { - "version": "1.3.2", - "resolved": "https://registry.npmjs.org/@radix-ui/react-slider/-/react-slider-1.3.2.tgz", - "integrity": "sha512-oQnqfgSiYkxZ1MrF6672jw2/zZvpB+PJsrIc3Zm1zof1JHf/kj7WhmROw7JahLfOwYQ5/+Ip0rFORgF1tjSiaQ==", + "version": "1.3.5", + "resolved": "https://registry.npmjs.org/@radix-ui/react-slider/-/react-slider-1.3.5.tgz", + "integrity": "sha512-rkfe2pU2NBAYfGaxa3Mqosi7VZEWX5CxKaanRv0vZd4Zhl9fvQrg0VM93dv3xGLGfrHuoTRF3JXH8nb9g+B3fw==", "license": "MIT", "dependencies": { "@radix-ui/number": "1.1.1", "@radix-ui/primitive": "1.1.2", - "@radix-ui/react-collection": "1.1.4", + "@radix-ui/react-collection": "1.1.7", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-direction": "1.1.1", - "@radix-ui/react-primitive": "2.1.0", + "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-controllable-state": "1.2.2", "@radix-ui/react-use-layout-effect": "1.1.1", "@radix-ui/react-use-previous": "1.1.1", @@ -1629,9 +1736,9 @@ } }, "node_modules/@radix-ui/react-slot": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.2.0.tgz", - "integrity": "sha512-ujc+V6r0HNDviYqIK3rW4ffgYiZ8g5DEHrGJVk4x7kTlLXRDILnKX9vAUYeIsLOoDpDJ0ujpqMkjH4w2ofuo6w==", + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.2.3.tgz", + "integrity": "sha512-aeNmHnBxbi2St0au6VBVC7JXFlhLlOnvIIlePNniyUNAClzmtAUEY8/pBiK3iHjufOlwA+c20/8jngo7xcrg8A==", "license": "MIT", "dependencies": { "@radix-ui/react-compose-refs": "1.1.2" @@ -1647,15 +1754,15 @@ } }, "node_modules/@radix-ui/react-switch": { - "version": "1.2.2", - "resolved": 
"https://registry.npmjs.org/@radix-ui/react-switch/-/react-switch-1.2.2.tgz", - "integrity": "sha512-7Z8n6L+ifMIIYZ83f28qWSceUpkXuslI2FJ34+kDMTiyj91ENdpdQ7VCidrzj5JfwfZTeano/BnGBbu/jqa5rQ==", + "version": "1.2.5", + "resolved": "https://registry.npmjs.org/@radix-ui/react-switch/-/react-switch-1.2.5.tgz", + "integrity": "sha512-5ijLkak6ZMylXsaImpZ8u4Rlf5grRmoc0p0QeX9VJtlrM4f5m3nCTX8tWga/zOA8PZYIR/t0p2Mnvd7InrJ6yQ==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", - "@radix-ui/react-primitive": "2.1.0", + "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-controllable-state": "1.2.2", "@radix-ui/react-use-previous": "1.1.1", "@radix-ui/react-use-size": "1.1.1" @@ -1676,9 +1783,9 @@ } }, "node_modules/@radix-ui/react-tabs": { - "version": "1.1.9", - "resolved": "https://registry.npmjs.org/@radix-ui/react-tabs/-/react-tabs-1.1.9.tgz", - "integrity": "sha512-KIjtwciYvquiW/wAFkELZCVnaNLBsYNhTNcvl+zfMAbMhRkcvNuCLXDDd22L0j7tagpzVh/QwbFpwAATg7ILPw==", + "version": "1.1.12", + "resolved": "https://registry.npmjs.org/@radix-ui/react-tabs/-/react-tabs-1.1.12.tgz", + "integrity": "sha512-GTVAlRVrQrSw3cEARM0nAx73ixrWDPNZAruETn3oHCNP6SbZ/hNxdxp+u7VkIEv3/sFoLq1PfcHrl7Pnp0CDpw==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", @@ -1686,8 +1793,8 @@ "@radix-ui/react-direction": "1.1.1", "@radix-ui/react-id": "1.1.1", "@radix-ui/react-presence": "1.1.4", - "@radix-ui/react-primitive": "2.1.0", - "@radix-ui/react-roving-focus": "1.1.7", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-roving-focus": "1.1.10", "@radix-ui/react-use-controllable-state": "1.2.2" }, "peerDependencies": { @@ -1706,23 +1813,23 @@ } }, "node_modules/@radix-ui/react-toast": { - "version": "1.2.11", - "resolved": "https://registry.npmjs.org/@radix-ui/react-toast/-/react-toast-1.2.11.tgz", - "integrity": 
"sha512-Ed2mlOmT+tktOsu2NZBK1bCSHh/uqULu1vWOkpQTVq53EoOuZUZw7FInQoDB3uil5wZc2oe0XN9a7uVZB7/6AQ==", + "version": "1.2.14", + "resolved": "https://registry.npmjs.org/@radix-ui/react-toast/-/react-toast-1.2.14.tgz", + "integrity": "sha512-nAP5FBxBJGQ/YfUB+r+O6USFVkWq3gAInkxyEnmvEV5jtSbfDhfa4hwX8CraCnbjMLsE7XSf/K75l9xXY7joWg==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", - "@radix-ui/react-collection": "1.1.4", + "@radix-ui/react-collection": "1.1.7", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", - "@radix-ui/react-dismissable-layer": "1.1.7", - "@radix-ui/react-portal": "1.1.6", + "@radix-ui/react-dismissable-layer": "1.1.10", + "@radix-ui/react-portal": "1.1.9", "@radix-ui/react-presence": "1.1.4", - "@radix-ui/react-primitive": "2.1.0", + "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-callback-ref": "1.1.1", "@radix-ui/react-use-controllable-state": "1.2.2", "@radix-ui/react-use-layout-effect": "1.1.1", - "@radix-ui/react-visually-hidden": "1.2.0" + "@radix-ui/react-visually-hidden": "1.2.3" }, "peerDependencies": { "@types/react": "*", @@ -1740,13 +1847,13 @@ } }, "node_modules/@radix-ui/react-toggle": { - "version": "1.1.6", - "resolved": "https://registry.npmjs.org/@radix-ui/react-toggle/-/react-toggle-1.1.6.tgz", - "integrity": "sha512-3SeJxKeO3TO1zVw1Nl++Cp0krYk6zHDHMCUXXVkosIzl6Nxcvb07EerQpyD2wXQSJ5RZajrYAmPaydU8Hk1IyQ==", + "version": "1.1.9", + "resolved": "https://registry.npmjs.org/@radix-ui/react-toggle/-/react-toggle-1.1.9.tgz", + "integrity": "sha512-ZoFkBBz9zv9GWer7wIjvdRxmh2wyc2oKWw6C6CseWd6/yq1DK/l5lJ+wnsmFwJZbBYqr02mrf8A2q/CVCuM3ZA==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", - "@radix-ui/react-primitive": "2.1.0", + "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-controllable-state": "1.2.2" }, "peerDependencies": { @@ -1765,17 +1872,17 @@ } }, "node_modules/@radix-ui/react-toggle-group": { - "version": "1.1.7", - "resolved": 
"https://registry.npmjs.org/@radix-ui/react-toggle-group/-/react-toggle-group-1.1.7.tgz", - "integrity": "sha512-GRaPJhxrRSOqAcmcX3MwRL/SZACkoYdmoY9/sg7Bd5DhBYsB2t4co0NxTvVW8H7jUmieQDQwRtUlZ5Ta8UbgJA==", + "version": "1.1.10", + "resolved": "https://registry.npmjs.org/@radix-ui/react-toggle-group/-/react-toggle-group-1.1.10.tgz", + "integrity": "sha512-kiU694Km3WFLTC75DdqgM/3Jauf3rD9wxeS9XtyWFKsBUeZA337lC+6uUazT7I1DhanZ5gyD5Stf8uf2dbQxOQ==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-direction": "1.1.1", - "@radix-ui/react-primitive": "2.1.0", - "@radix-ui/react-roving-focus": "1.1.7", - "@radix-ui/react-toggle": "1.1.6", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-roving-focus": "1.1.10", + "@radix-ui/react-toggle": "1.1.9", "@radix-ui/react-use-controllable-state": "1.2.2" }, "peerDependencies": { @@ -1794,18 +1901,18 @@ } }, "node_modules/@radix-ui/react-toolbar": { - "version": "1.1.7", - "resolved": "https://registry.npmjs.org/@radix-ui/react-toolbar/-/react-toolbar-1.1.7.tgz", - "integrity": "sha512-cL/3snRskM0f955waP+m4Pmr8+QOPpPsfoY5kM06k7eWP41diOcyjLEqSxpd/K9S7fpsV66yq4R6yN2sMwXc6Q==", + "version": "1.1.10", + "resolved": "https://registry.npmjs.org/@radix-ui/react-toolbar/-/react-toolbar-1.1.10.tgz", + "integrity": "sha512-jiwQsduEL++M4YBIurjSa+voD86OIytCod0/dbIxFZDLD8NfO1//keXYMfsW8BPcfqwoNjt+y06XcJqAb4KR7A==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-direction": "1.1.1", - "@radix-ui/react-primitive": "2.1.0", - "@radix-ui/react-roving-focus": "1.1.7", - "@radix-ui/react-separator": "1.1.4", - "@radix-ui/react-toggle-group": "1.1.7" + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-roving-focus": "1.1.10", + "@radix-ui/react-separator": "1.1.7", + "@radix-ui/react-toggle-group": "1.1.10" }, "peerDependencies": { "@types/react": "*", @@ -1823,23 +1930,23 @@ } }, 
"node_modules/@radix-ui/react-tooltip": { - "version": "1.2.4", - "resolved": "https://registry.npmjs.org/@radix-ui/react-tooltip/-/react-tooltip-1.2.4.tgz", - "integrity": "sha512-DyW8VVeeMSSLFvAmnVnCwvI3H+1tpJFHT50r+tdOoMse9XqYDBCcyux8u3G2y+LOpt7fPQ6KKH0mhs+ce1+Z5w==", + "version": "1.2.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-tooltip/-/react-tooltip-1.2.7.tgz", + "integrity": "sha512-Ap+fNYwKTYJ9pzqW+Xe2HtMRbQ/EeWkj2qykZ6SuEV4iS/o1bZI5ssJbk4D2r8XuDuOBVz/tIx2JObtuqU+5Zw==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", - "@radix-ui/react-dismissable-layer": "1.1.7", + "@radix-ui/react-dismissable-layer": "1.1.10", "@radix-ui/react-id": "1.1.1", - "@radix-ui/react-popper": "1.2.4", - "@radix-ui/react-portal": "1.1.6", + "@radix-ui/react-popper": "1.2.7", + "@radix-ui/react-portal": "1.1.9", "@radix-ui/react-presence": "1.1.4", - "@radix-ui/react-primitive": "2.1.0", - "@radix-ui/react-slot": "1.2.0", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-slot": "1.2.3", "@radix-ui/react-use-controllable-state": "1.2.2", - "@radix-ui/react-visually-hidden": "1.2.0" + "@radix-ui/react-visually-hidden": "1.2.3" }, "peerDependencies": { "@types/react": "*", @@ -2011,12 +2118,12 @@ } }, "node_modules/@radix-ui/react-visually-hidden": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/@radix-ui/react-visually-hidden/-/react-visually-hidden-1.2.0.tgz", - "integrity": "sha512-rQj0aAWOpCdCMRbI6pLQm8r7S2BM3YhTa0SzOYD55k+hJA8oo9J+H+9wLM9oMlZWOX/wJWPTzfDfmZkf7LvCfg==", + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-visually-hidden/-/react-visually-hidden-1.2.3.tgz", + "integrity": "sha512-pzJq12tEaaIhqjbzpCuv/OypJY/BPavOofm+dbab+MHLajy277+1lLm6JFcGgF5eskJ6mquGirhXY2GD/8u8Ug==", "license": "MIT", "dependencies": { - "@radix-ui/react-primitive": "2.1.0" + "@radix-ui/react-primitive": "2.1.3" }, 
"peerDependencies": { "@types/react": "*", @@ -2109,46 +2216,54 @@ } }, "node_modules/@tailwindcss/node": { - "version": "4.1.4", - "resolved": "https://registry.npmjs.org/@tailwindcss/node/-/node-4.1.4.tgz", - "integrity": "sha512-MT5118zaiO6x6hNA04OWInuAiP1YISXql8Z+/Y8iisV5nuhM8VXlyhRuqc2PEviPszcXI66W44bCIk500Oolhw==", + "version": "4.1.8", + "resolved": "https://registry.npmjs.org/@tailwindcss/node/-/node-4.1.8.tgz", + "integrity": "sha512-OWwBsbC9BFAJelmnNcrKuf+bka2ZxCE2A4Ft53Tkg4uoiE67r/PMEYwCsourC26E+kmxfwE0hVzMdxqeW+xu7Q==", "dev": true, "license": "MIT", "dependencies": { + "@ampproject/remapping": "^2.3.0", "enhanced-resolve": "^5.18.1", "jiti": "^2.4.2", - "lightningcss": "1.29.2", - "tailwindcss": "4.1.4" + "lightningcss": "1.30.1", + "magic-string": "^0.30.17", + "source-map-js": "^1.2.1", + "tailwindcss": "4.1.8" } }, "node_modules/@tailwindcss/oxide": { - "version": "4.1.4", - "resolved": "https://registry.npmjs.org/@tailwindcss/oxide/-/oxide-4.1.4.tgz", - "integrity": "sha512-p5wOpXyOJx7mKh5MXh5oKk+kqcz8T+bA3z/5VWWeQwFrmuBItGwz8Y2CHk/sJ+dNb9B0nYFfn0rj/cKHZyjahQ==", + "version": "4.1.8", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide/-/oxide-4.1.8.tgz", + "integrity": "sha512-d7qvv9PsM5N3VNKhwVUhpK6r4h9wtLkJ6lz9ZY9aeZgrUWk1Z8VPyqyDT9MZlem7GTGseRQHkeB1j3tC7W1P+A==", "dev": true, + "hasInstallScript": true, "license": "MIT", + "dependencies": { + "detect-libc": "^2.0.4", + "tar": "^7.4.3" + }, "engines": { "node": ">= 10" }, "optionalDependencies": { - "@tailwindcss/oxide-android-arm64": "4.1.4", - "@tailwindcss/oxide-darwin-arm64": "4.1.4", - "@tailwindcss/oxide-darwin-x64": "4.1.4", - "@tailwindcss/oxide-freebsd-x64": "4.1.4", - "@tailwindcss/oxide-linux-arm-gnueabihf": "4.1.4", - "@tailwindcss/oxide-linux-arm64-gnu": "4.1.4", - "@tailwindcss/oxide-linux-arm64-musl": "4.1.4", - "@tailwindcss/oxide-linux-x64-gnu": "4.1.4", - "@tailwindcss/oxide-linux-x64-musl": "4.1.4", - "@tailwindcss/oxide-wasm32-wasi": "4.1.4", - 
"@tailwindcss/oxide-win32-arm64-msvc": "4.1.4", - "@tailwindcss/oxide-win32-x64-msvc": "4.1.4" + "@tailwindcss/oxide-android-arm64": "4.1.8", + "@tailwindcss/oxide-darwin-arm64": "4.1.8", + "@tailwindcss/oxide-darwin-x64": "4.1.8", + "@tailwindcss/oxide-freebsd-x64": "4.1.8", + "@tailwindcss/oxide-linux-arm-gnueabihf": "4.1.8", + "@tailwindcss/oxide-linux-arm64-gnu": "4.1.8", + "@tailwindcss/oxide-linux-arm64-musl": "4.1.8", + "@tailwindcss/oxide-linux-x64-gnu": "4.1.8", + "@tailwindcss/oxide-linux-x64-musl": "4.1.8", + "@tailwindcss/oxide-wasm32-wasi": "4.1.8", + "@tailwindcss/oxide-win32-arm64-msvc": "4.1.8", + "@tailwindcss/oxide-win32-x64-msvc": "4.1.8" } }, "node_modules/@tailwindcss/oxide-android-arm64": { - "version": "4.1.4", - "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-android-arm64/-/oxide-android-arm64-4.1.4.tgz", - "integrity": "sha512-xMMAe/SaCN/vHfQYui3fqaBDEXMu22BVwQ33veLc8ep+DNy7CWN52L+TTG9y1K397w9nkzv+Mw+mZWISiqhmlA==", + "version": "4.1.8", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-android-arm64/-/oxide-android-arm64-4.1.8.tgz", + "integrity": "sha512-Fbz7qni62uKYceWYvUjRqhGfZKwhZDQhlrJKGtnZfuNtHFqa8wmr+Wn74CTWERiW2hn3mN5gTpOoxWKk0jRxjg==", "cpu": [ "arm64" ], @@ -2163,9 +2278,9 @@ } }, "node_modules/@tailwindcss/oxide-darwin-arm64": { - "version": "4.1.4", - "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-darwin-arm64/-/oxide-darwin-arm64-4.1.4.tgz", - "integrity": "sha512-JGRj0SYFuDuAGilWFBlshcexev2hOKfNkoX+0QTksKYq2zgF9VY/vVMq9m8IObYnLna0Xlg+ytCi2FN2rOL0Sg==", + "version": "4.1.8", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-darwin-arm64/-/oxide-darwin-arm64-4.1.8.tgz", + "integrity": "sha512-RdRvedGsT0vwVVDztvyXhKpsU2ark/BjgG0huo4+2BluxdXo8NDgzl77qh0T1nUxmM11eXwR8jA39ibvSTbi7A==", "cpu": [ "arm64" ], @@ -2180,9 +2295,9 @@ } }, "node_modules/@tailwindcss/oxide-darwin-x64": { - "version": "4.1.4", - "resolved": 
"https://registry.npmjs.org/@tailwindcss/oxide-darwin-x64/-/oxide-darwin-x64-4.1.4.tgz", - "integrity": "sha512-sdDeLNvs3cYeWsEJ4H1DvjOzaGios4QbBTNLVLVs0XQ0V95bffT3+scptzYGPMjm7xv4+qMhCDrkHwhnUySEzA==", + "version": "4.1.8", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-darwin-x64/-/oxide-darwin-x64-4.1.8.tgz", + "integrity": "sha512-t6PgxjEMLp5Ovf7uMb2OFmb3kqzVTPPakWpBIFzppk4JE4ix0yEtbtSjPbU8+PZETpaYMtXvss2Sdkx8Vs4XRw==", "cpu": [ "x64" ], @@ -2197,9 +2312,9 @@ } }, "node_modules/@tailwindcss/oxide-freebsd-x64": { - "version": "4.1.4", - "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-freebsd-x64/-/oxide-freebsd-x64-4.1.4.tgz", - "integrity": "sha512-VHxAqxqdghM83HslPhRsNhHo91McsxRJaEnShJOMu8mHmEj9Ig7ToHJtDukkuLWLzLboh2XSjq/0zO6wgvykNA==", + "version": "4.1.8", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-freebsd-x64/-/oxide-freebsd-x64-4.1.8.tgz", + "integrity": "sha512-g8C8eGEyhHTqwPStSwZNSrOlyx0bhK/V/+zX0Y+n7DoRUzyS8eMbVshVOLJTDDC+Qn9IJnilYbIKzpB9n4aBsg==", "cpu": [ "x64" ], @@ -2214,9 +2329,9 @@ } }, "node_modules/@tailwindcss/oxide-linux-arm-gnueabihf": { - "version": "4.1.4", - "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-linux-arm-gnueabihf/-/oxide-linux-arm-gnueabihf-4.1.4.tgz", - "integrity": "sha512-OTU/m/eV4gQKxy9r5acuesqaymyeSCnsx1cFto/I1WhPmi5HDxX1nkzb8KYBiwkHIGg7CTfo/AcGzoXAJBxLfg==", + "version": "4.1.8", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-linux-arm-gnueabihf/-/oxide-linux-arm-gnueabihf-4.1.8.tgz", + "integrity": "sha512-Jmzr3FA4S2tHhaC6yCjac3rGf7hG9R6Gf2z9i9JFcuyy0u79HfQsh/thifbYTF2ic82KJovKKkIB6Z9TdNhCXQ==", "cpu": [ "arm" ], @@ -2231,9 +2346,9 @@ } }, "node_modules/@tailwindcss/oxide-linux-arm64-gnu": { - "version": "4.1.4", - "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-linux-arm64-gnu/-/oxide-linux-arm64-gnu-4.1.4.tgz", - "integrity": "sha512-hKlLNvbmUC6z5g/J4H+Zx7f7w15whSVImokLPmP6ff1QqTVE+TxUM9PGuNsjHvkvlHUtGTdDnOvGNSEUiXI1Ww==", + "version": 
"4.1.8", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-linux-arm64-gnu/-/oxide-linux-arm64-gnu-4.1.8.tgz", + "integrity": "sha512-qq7jXtO1+UEtCmCeBBIRDrPFIVI4ilEQ97qgBGdwXAARrUqSn/L9fUrkb1XP/mvVtoVeR2bt/0L77xx53bPZ/Q==", "cpu": [ "arm64" ], @@ -2248,9 +2363,9 @@ } }, "node_modules/@tailwindcss/oxide-linux-arm64-musl": { - "version": "4.1.4", - "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-linux-arm64-musl/-/oxide-linux-arm64-musl-4.1.4.tgz", - "integrity": "sha512-X3As2xhtgPTY/m5edUtddmZ8rCruvBvtxYLMw9OsZdH01L2gS2icsHRwxdU0dMItNfVmrBezueXZCHxVeeb7Aw==", + "version": "4.1.8", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-linux-arm64-musl/-/oxide-linux-arm64-musl-4.1.8.tgz", + "integrity": "sha512-O6b8QesPbJCRshsNApsOIpzKt3ztG35gfX9tEf4arD7mwNinsoCKxkj8TgEE0YRjmjtO3r9FlJnT/ENd9EVefQ==", "cpu": [ "arm64" ], @@ -2265,9 +2380,9 @@ } }, "node_modules/@tailwindcss/oxide-linux-x64-gnu": { - "version": "4.1.4", - "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-linux-x64-gnu/-/oxide-linux-x64-gnu-4.1.4.tgz", - "integrity": "sha512-2VG4DqhGaDSmYIu6C4ua2vSLXnJsb/C9liej7TuSO04NK+JJJgJucDUgmX6sn7Gw3Cs5ZJ9ZLrnI0QRDOjLfNQ==", + "version": "4.1.8", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-linux-x64-gnu/-/oxide-linux-x64-gnu-4.1.8.tgz", + "integrity": "sha512-32iEXX/pXwikshNOGnERAFwFSfiltmijMIAbUhnNyjFr3tmWmMJWQKU2vNcFX0DACSXJ3ZWcSkzNbaKTdngH6g==", "cpu": [ "x64" ], @@ -2282,9 +2397,9 @@ } }, "node_modules/@tailwindcss/oxide-linux-x64-musl": { - "version": "4.1.4", - "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-linux-x64-musl/-/oxide-linux-x64-musl-4.1.4.tgz", - "integrity": "sha512-v+mxVgH2kmur/X5Mdrz9m7TsoVjbdYQT0b4Z+dr+I4RvreCNXyCFELZL/DO0M1RsidZTrm6O1eMnV6zlgEzTMQ==", + "version": "4.1.8", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-linux-x64-musl/-/oxide-linux-x64-musl-4.1.8.tgz", + "integrity": 
"sha512-s+VSSD+TfZeMEsCaFaHTaY5YNj3Dri8rST09gMvYQKwPphacRG7wbuQ5ZJMIJXN/puxPcg/nU+ucvWguPpvBDg==", "cpu": [ "x64" ], @@ -2299,9 +2414,9 @@ } }, "node_modules/@tailwindcss/oxide-wasm32-wasi": { - "version": "4.1.4", - "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-wasm32-wasi/-/oxide-wasm32-wasi-4.1.4.tgz", - "integrity": "sha512-2TLe9ir+9esCf6Wm+lLWTMbgklIjiF0pbmDnwmhR9MksVOq+e8aP3TSsXySnBDDvTTVd/vKu1aNttEGj3P6l8Q==", + "version": "4.1.8", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-wasm32-wasi/-/oxide-wasm32-wasi-4.1.8.tgz", + "integrity": "sha512-CXBPVFkpDjM67sS1psWohZ6g/2/cd+cq56vPxK4JeawelxwK4YECgl9Y9TjkE2qfF+9/s1tHHJqrC4SS6cVvSg==", "bundleDependencies": [ "@napi-rs/wasm-runtime", "@emnapi/core", @@ -2317,10 +2432,10 @@ "license": "MIT", "optional": true, "dependencies": { - "@emnapi/core": "^1.4.0", - "@emnapi/runtime": "^1.4.0", - "@emnapi/wasi-threads": "^1.0.1", - "@napi-rs/wasm-runtime": "^0.2.8", + "@emnapi/core": "^1.4.3", + "@emnapi/runtime": "^1.4.3", + "@emnapi/wasi-threads": "^1.0.2", + "@napi-rs/wasm-runtime": "^0.2.10", "@tybys/wasm-util": "^0.9.0", "tslib": "^2.8.0" }, @@ -2329,9 +2444,9 @@ } }, "node_modules/@tailwindcss/oxide-win32-arm64-msvc": { - "version": "4.1.4", - "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-win32-arm64-msvc/-/oxide-win32-arm64-msvc-4.1.4.tgz", - "integrity": "sha512-VlnhfilPlO0ltxW9/BgfLI5547PYzqBMPIzRrk4W7uupgCt8z6Trw/tAj6QUtF2om+1MH281Pg+HHUJoLesmng==", + "version": "4.1.8", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-win32-arm64-msvc/-/oxide-win32-arm64-msvc-4.1.8.tgz", + "integrity": "sha512-7GmYk1n28teDHUjPlIx4Z6Z4hHEgvP5ZW2QS9ygnDAdI/myh3HTHjDqtSqgu1BpRoI4OiLx+fThAyA1JePoENA==", "cpu": [ "arm64" ], @@ -2346,9 +2461,9 @@ } }, "node_modules/@tailwindcss/oxide-win32-x64-msvc": { - "version": "4.1.4", - "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-win32-x64-msvc/-/oxide-win32-x64-msvc-4.1.4.tgz", - "integrity": 
"sha512-+7S63t5zhYjslUGb8NcgLpFXD+Kq1F/zt5Xv5qTv7HaFTG/DHyHD9GA6ieNAxhgyA4IcKa/zy7Xx4Oad2/wuhw==", + "version": "4.1.8", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-win32-x64-msvc/-/oxide-win32-x64-msvc-4.1.8.tgz", + "integrity": "sha512-fou+U20j+Jl0EHwK92spoWISON2OBnCazIc038Xj2TdweYV33ZRkS9nwqiUi2d/Wba5xg5UoHfvynnb/UB49cQ==", "cpu": [ "x64" ], @@ -2363,17 +2478,17 @@ } }, "node_modules/@tailwindcss/postcss": { - "version": "4.1.4", - "resolved": "https://registry.npmjs.org/@tailwindcss/postcss/-/postcss-4.1.4.tgz", - "integrity": "sha512-bjV6sqycCEa+AQSt2Kr7wpGF1bOZJ5wsqnLEkqSbM/JEHxx/yhMH8wHmdkPyApF9xhHeMSwnnkDUUMMM/hYnXw==", + "version": "4.1.8", + "resolved": "https://registry.npmjs.org/@tailwindcss/postcss/-/postcss-4.1.8.tgz", + "integrity": "sha512-vB/vlf7rIky+w94aWMw34bWW1ka6g6C3xIOdICKX2GC0VcLtL6fhlLiafF0DVIwa9V6EHz8kbWMkS2s2QvvNlw==", "dev": true, "license": "MIT", "dependencies": { "@alloc/quick-lru": "^5.2.0", - "@tailwindcss/node": "4.1.4", - "@tailwindcss/oxide": "4.1.4", + "@tailwindcss/node": "4.1.8", + "@tailwindcss/oxide": "4.1.8", "postcss": "^8.4.41", - "tailwindcss": "4.1.4" + "tailwindcss": "4.1.8" } }, "node_modules/@tybys/wasm-util": { @@ -2479,9 +2594,9 @@ "license": "MIT" }, "node_modules/@types/node": { - "version": "20.17.32", - "resolved": "https://registry.npmjs.org/@types/node/-/node-20.17.32.tgz", - "integrity": "sha512-zeMXFn8zQ+UkjK4ws0RiOC9EWByyW1CcVmLe+2rQocXRsGEDxUCwPEIVgpsGcLHS/P8JkT0oa3839BRABS0oPw==", + "version": "20.17.57", + "resolved": "https://registry.npmjs.org/@types/node/-/node-20.17.57.tgz", + "integrity": "sha512-f3T4y6VU4fVQDKVqJV4Uppy8c1p/sVvS3peyqxyWnzkqXFJLRU7Y1Bl7rMS1Qe9z0v4M6McY0Fp9yBsgHJUsWQ==", "dev": true, "license": "MIT", "dependencies": { @@ -2496,9 +2611,9 @@ "license": "MIT" }, "node_modules/@types/react": { - "version": "18.3.20", - "resolved": "https://registry.npmjs.org/@types/react/-/react-18.3.20.tgz", - "integrity": 
"sha512-IPaCZN7PShZK/3t6Q87pfTkRm6oLTd4vztyoj+cbHUF1g3FfVb2tFIL79uCRKEfv16AhqDMBywP2VW3KIZUvcg==", + "version": "18.3.23", + "resolved": "https://registry.npmjs.org/@types/react/-/react-18.3.23.tgz", + "integrity": "sha512-/LDXMQh55EzZQ0uVAZmKKhfENivEvWz6E+EYzh+/MCjMhNsotd+ZHhBGIjFDTi6+fz0OhQQQLbTgdQIxxCsC0w==", "devOptional": true, "license": "MIT", "dependencies": { @@ -2507,9 +2622,9 @@ } }, "node_modules/@types/react-dom": { - "version": "18.3.6", - "resolved": "https://registry.npmjs.org/@types/react-dom/-/react-dom-18.3.6.tgz", - "integrity": "sha512-nf22//wEbKXusP6E9pfOCDwFdHAX4u172eaJI4YkDRQEZiorm6KfYnSC2SWLDMVWUOWPERmJnN0ujeAfTBLvrw==", + "version": "18.3.7", + "resolved": "https://registry.npmjs.org/@types/react-dom/-/react-dom-18.3.7.tgz", + "integrity": "sha512-MEe3UeoENYVFXzoXEWsvcpg6ZvlrFNlOQ7EOsvhI3CfAXwzPfO8Qwuxd40nepsYKqyyVQnTdEfv68q91yLcKrQ==", "devOptional": true, "license": "MIT", "peerDependencies": { @@ -2727,9 +2842,9 @@ "license": "ISC" }, "node_modules/@unrs/resolver-binding-darwin-arm64": { - "version": "1.7.2", - "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-darwin-arm64/-/resolver-binding-darwin-arm64-1.7.2.tgz", - "integrity": "sha512-vxtBno4xvowwNmO/ASL0Y45TpHqmNkAaDtz4Jqb+clmcVSSl8XCG/PNFFkGsXXXS6AMjP+ja/TtNCFFa1QwLRg==", + "version": "1.7.10", + "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-darwin-arm64/-/resolver-binding-darwin-arm64-1.7.10.tgz", + "integrity": "sha512-ABsM3eEiL3yu903G0uxgvGAoIw011XjTzyEk//gGtuVY1PuXP2IJG6novd6DBjm7MaWmRV/CZFY1rWBXSlSVVw==", "cpu": [ "arm64" ], @@ -2741,9 +2856,9 @@ ] }, "node_modules/@unrs/resolver-binding-darwin-x64": { - "version": "1.7.2", - "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-darwin-x64/-/resolver-binding-darwin-x64-1.7.2.tgz", - "integrity": "sha512-qhVa8ozu92C23Hsmv0BF4+5Dyyd5STT1FolV4whNgbY6mj3kA0qsrGPe35zNR3wAN7eFict3s4Rc2dDTPBTuFQ==", + "version": "1.7.10", + "resolved": 
"https://registry.npmjs.org/@unrs/resolver-binding-darwin-x64/-/resolver-binding-darwin-x64-1.7.10.tgz", + "integrity": "sha512-lGVWy4FQEDo/PuI1VQXaQCY0XUg4xUJilf3fQ8NY4wtsQTm9lbasbUYf3nkoma+O2/do90jQTqkb02S3meyTDg==", "cpu": [ "x64" ], @@ -2755,9 +2870,9 @@ ] }, "node_modules/@unrs/resolver-binding-freebsd-x64": { - "version": "1.7.2", - "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-freebsd-x64/-/resolver-binding-freebsd-x64-1.7.2.tgz", - "integrity": "sha512-zKKdm2uMXqLFX6Ac7K5ElnnG5VIXbDlFWzg4WJ8CGUedJryM5A3cTgHuGMw1+P5ziV8CRhnSEgOnurTI4vpHpg==", + "version": "1.7.10", + "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-freebsd-x64/-/resolver-binding-freebsd-x64-1.7.10.tgz", + "integrity": "sha512-g9XLCHzNGatY79JJNgxrUH6uAAfBDj2NWIlTnqQN5odwGKjyVfFZ5tFL1OxYPcxTHh384TY5lvTtF+fuEZNvBQ==", "cpu": [ "x64" ], @@ -2769,9 +2884,9 @@ ] }, "node_modules/@unrs/resolver-binding-linux-arm-gnueabihf": { - "version": "1.7.2", - "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-linux-arm-gnueabihf/-/resolver-binding-linux-arm-gnueabihf-1.7.2.tgz", - "integrity": "sha512-8N1z1TbPnHH+iDS/42GJ0bMPLiGK+cUqOhNbMKtWJ4oFGzqSJk/zoXFzcQkgtI63qMcUI7wW1tq2usZQSb2jxw==", + "version": "1.7.10", + "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-linux-arm-gnueabihf/-/resolver-binding-linux-arm-gnueabihf-1.7.10.tgz", + "integrity": "sha512-zV0ZMNy50sJFJapsjec8onyL9YREQKT88V8KwMoOA+zki/duFUP0oyTlbax1jGKdh8rQnruvW9VYkovGvdBAsw==", "cpu": [ "arm" ], @@ -2783,9 +2898,9 @@ ] }, "node_modules/@unrs/resolver-binding-linux-arm-musleabihf": { - "version": "1.7.2", - "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-linux-arm-musleabihf/-/resolver-binding-linux-arm-musleabihf-1.7.2.tgz", - "integrity": "sha512-tjYzI9LcAXR9MYd9rO45m1s0B/6bJNuZ6jeOxo1pq1K6OBuRMMmfyvJYval3s9FPPGmrldYA3mi4gWDlWuTFGA==", + "version": "1.7.10", + "resolved": 
"https://registry.npmjs.org/@unrs/resolver-binding-linux-arm-musleabihf/-/resolver-binding-linux-arm-musleabihf-1.7.10.tgz", + "integrity": "sha512-jQxgb1DIDI7goyrabh4uvyWWBrFRfF+OOnS9SbF15h52g3Qjn/u8zG7wOQ0NjtcSMftzO75TITu9aHuI7FcqQQ==", "cpu": [ "arm" ], @@ -2797,9 +2912,9 @@ ] }, "node_modules/@unrs/resolver-binding-linux-arm64-gnu": { - "version": "1.7.2", - "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-linux-arm64-gnu/-/resolver-binding-linux-arm64-gnu-1.7.2.tgz", - "integrity": "sha512-jon9M7DKRLGZ9VYSkFMflvNqu9hDtOCEnO2QAryFWgT6o6AXU8du56V7YqnaLKr6rAbZBWYsYpikF226v423QA==", + "version": "1.7.10", + "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-linux-arm64-gnu/-/resolver-binding-linux-arm64-gnu-1.7.10.tgz", + "integrity": "sha512-9wVVlO6+aNlm90YWitwSI++HyCyBkzYCwMi7QbuGrTxDFm2pAgtpT0OEliaI7tLS8lAWYuDbzRRCJDgsdm6nwg==", "cpu": [ "arm64" ], @@ -2811,9 +2926,9 @@ ] }, "node_modules/@unrs/resolver-binding-linux-arm64-musl": { - "version": "1.7.2", - "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-linux-arm64-musl/-/resolver-binding-linux-arm64-musl-1.7.2.tgz", - "integrity": "sha512-c8Cg4/h+kQ63pL43wBNaVMmOjXI/X62wQmru51qjfTvI7kmCy5uHTJvK/9LrF0G8Jdx8r34d019P1DVJmhXQpA==", + "version": "1.7.10", + "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-linux-arm64-musl/-/resolver-binding-linux-arm64-musl-1.7.10.tgz", + "integrity": "sha512-FtFweORChdXOes0RAAyTZp6I4PodU2cZiSILAbGaEKDXp378UOumD2vaAkWHNxpsreQUKRxG5O1uq9EoV1NiVQ==", "cpu": [ "arm64" ], @@ -2825,9 +2940,9 @@ ] }, "node_modules/@unrs/resolver-binding-linux-ppc64-gnu": { - "version": "1.7.2", - "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-linux-ppc64-gnu/-/resolver-binding-linux-ppc64-gnu-1.7.2.tgz", - "integrity": "sha512-A+lcwRFyrjeJmv3JJvhz5NbcCkLQL6Mk16kHTNm6/aGNc4FwPHPE4DR9DwuCvCnVHvF5IAd9U4VIs/VvVir5lg==", + "version": "1.7.10", + "resolved": 
"https://registry.npmjs.org/@unrs/resolver-binding-linux-ppc64-gnu/-/resolver-binding-linux-ppc64-gnu-1.7.10.tgz", + "integrity": "sha512-B+hOjpG2ncCR96a9d9ww1dWVuRVC2NChD0bITgrUhEWBhpdv2o/Mu2l8MsB2fzjdV/ku+twaQhr8iLHBoZafZQ==", "cpu": [ "ppc64" ], @@ -2839,9 +2954,9 @@ ] }, "node_modules/@unrs/resolver-binding-linux-riscv64-gnu": { - "version": "1.7.2", - "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-linux-riscv64-gnu/-/resolver-binding-linux-riscv64-gnu-1.7.2.tgz", - "integrity": "sha512-hQQ4TJQrSQW8JlPm7tRpXN8OCNP9ez7PajJNjRD1ZTHQAy685OYqPrKjfaMw/8LiHCt8AZ74rfUVHP9vn0N69Q==", + "version": "1.7.10", + "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-linux-riscv64-gnu/-/resolver-binding-linux-riscv64-gnu-1.7.10.tgz", + "integrity": "sha512-DS6jFDoQCFsnsdLXlj3z3THakQLBic63B6A0rpQ1kpkyKa3OzEfqhwRNVaywuUuOKP9bX55Jk2uqpvn/hGjKCg==", "cpu": [ "riscv64" ], @@ -2853,9 +2968,9 @@ ] }, "node_modules/@unrs/resolver-binding-linux-riscv64-musl": { - "version": "1.7.2", - "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-linux-riscv64-musl/-/resolver-binding-linux-riscv64-musl-1.7.2.tgz", - "integrity": "sha512-NoAGbiqrxtY8kVooZ24i70CjLDlUFI7nDj3I9y54U94p+3kPxwd2L692YsdLa+cqQ0VoqMWoehDFp21PKRUoIQ==", + "version": "1.7.10", + "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-linux-riscv64-musl/-/resolver-binding-linux-riscv64-musl-1.7.10.tgz", + "integrity": "sha512-A82SB6yEaA8EhIW2r0I7P+k5lg7zPscFnGs1Gna5rfPwoZjeUAGX76T55+DiyTiy08VFKUi79PGCulXnfjDq0g==", "cpu": [ "riscv64" ], @@ -2867,9 +2982,9 @@ ] }, "node_modules/@unrs/resolver-binding-linux-s390x-gnu": { - "version": "1.7.2", - "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-linux-s390x-gnu/-/resolver-binding-linux-s390x-gnu-1.7.2.tgz", - "integrity": "sha512-KaZByo8xuQZbUhhreBTW+yUnOIHUsv04P8lKjQ5otiGoSJ17ISGYArc+4vKdLEpGaLbemGzr4ZeUbYQQsLWFjA==", + "version": "1.7.10", + "resolved": 
"https://registry.npmjs.org/@unrs/resolver-binding-linux-s390x-gnu/-/resolver-binding-linux-s390x-gnu-1.7.10.tgz", + "integrity": "sha512-J+VmOPH16U69QshCp9WS+Zuiuu9GWTISKchKIhLbS/6JSCEfw2A4N02whv2VmrkXE287xxZbhW1p6xlAXNzwqg==", "cpu": [ "s390x" ], @@ -2881,9 +2996,9 @@ ] }, "node_modules/@unrs/resolver-binding-linux-x64-gnu": { - "version": "1.7.2", - "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-linux-x64-gnu/-/resolver-binding-linux-x64-gnu-1.7.2.tgz", - "integrity": "sha512-dEidzJDubxxhUCBJ/SHSMJD/9q7JkyfBMT77Px1npl4xpg9t0POLvnWywSk66BgZS/b2Hy9Y1yFaoMTFJUe9yg==", + "version": "1.7.10", + "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-linux-x64-gnu/-/resolver-binding-linux-x64-gnu-1.7.10.tgz", + "integrity": "sha512-bYTdDltcB/V3fEqpx8YDwDw8ta9uEg8TUbJOtek6JM42u9ciJ7R/jBjNeAOs+QbyxGDd2d6xkBaGwty1HzOz3Q==", "cpu": [ "x64" ], @@ -2895,9 +3010,9 @@ ] }, "node_modules/@unrs/resolver-binding-linux-x64-musl": { - "version": "1.7.2", - "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-linux-x64-musl/-/resolver-binding-linux-x64-musl-1.7.2.tgz", - "integrity": "sha512-RvP+Ux3wDjmnZDT4XWFfNBRVG0fMsc+yVzNFUqOflnDfZ9OYujv6nkh+GOr+watwrW4wdp6ASfG/e7bkDradsw==", + "version": "1.7.10", + "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-linux-x64-musl/-/resolver-binding-linux-x64-musl-1.7.10.tgz", + "integrity": "sha512-NYZ1GvSuTokJ28lqcjrMTnGMySoo4dVcNK/nsNCKCXT++1zekZtJaE+N+4jc1kR7EV0fc1OhRrOGcSt7FT9t8w==", "cpu": [ "x64" ], @@ -2909,9 +3024,9 @@ ] }, "node_modules/@unrs/resolver-binding-wasm32-wasi": { - "version": "1.7.2", - "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-wasm32-wasi/-/resolver-binding-wasm32-wasi-1.7.2.tgz", - "integrity": "sha512-y797JBmO9IsvXVRCKDXOxjyAE4+CcZpla2GSoBQ33TVb3ILXuFnMrbR/QQZoauBYeOFuu4w3ifWLw52sdHGz6g==", + "version": "1.7.10", + "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-wasm32-wasi/-/resolver-binding-wasm32-wasi-1.7.10.tgz", + "integrity": 
"sha512-MRjJhTaQzLoX8OtzRBQDJ84OJ8IX1FqpRAUSxp/JtPeak+fyDfhXaEjcA/fhfgrACUnvC+jWC52f/V6MixSKCQ==", "cpu": [ "wasm32" ], @@ -2919,16 +3034,16 @@ "license": "MIT", "optional": true, "dependencies": { - "@napi-rs/wasm-runtime": "^0.2.9" + "@napi-rs/wasm-runtime": "^0.2.10" }, "engines": { "node": ">=14.0.0" } }, "node_modules/@unrs/resolver-binding-win32-arm64-msvc": { - "version": "1.7.2", - "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-win32-arm64-msvc/-/resolver-binding-win32-arm64-msvc-1.7.2.tgz", - "integrity": "sha512-gtYTh4/VREVSLA+gHrfbWxaMO/00y+34htY7XpioBTy56YN2eBjkPrY1ML1Zys89X3RJDKVaogzwxlM1qU7egg==", + "version": "1.7.10", + "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-win32-arm64-msvc/-/resolver-binding-win32-arm64-msvc-1.7.10.tgz", + "integrity": "sha512-Cgw6qhdsfzXJnHb006CzqgaX8mD445x5FGKuueaLeH1ptCxDbzRs8wDm6VieOI7rdbstfYBaFtaYN7zBT5CUPg==", "cpu": [ "arm64" ], @@ -2940,9 +3055,9 @@ ] }, "node_modules/@unrs/resolver-binding-win32-ia32-msvc": { - "version": "1.7.2", - "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-win32-ia32-msvc/-/resolver-binding-win32-ia32-msvc-1.7.2.tgz", - "integrity": "sha512-Ywv20XHvHTDRQs12jd3MY8X5C8KLjDbg/jyaal/QLKx3fAShhJyD4blEANInsjxW3P7isHx1Blt56iUDDJO3jg==", + "version": "1.7.10", + "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-win32-ia32-msvc/-/resolver-binding-win32-ia32-msvc-1.7.10.tgz", + "integrity": "sha512-Z7oECyIT2/HsrWpJ6wi2b+lVbPmWqQHuW5zeatafoRXizk1+2wUl+aSop1PF58XcyBuwPP2YpEUUpMZ8ILV4fA==", "cpu": [ "ia32" ], @@ -2954,9 +3069,9 @@ ] }, "node_modules/@unrs/resolver-binding-win32-x64-msvc": { - "version": "1.7.2", - "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-win32-x64-msvc/-/resolver-binding-win32-x64-msvc-1.7.2.tgz", - "integrity": "sha512-friS8NEQfHaDbkThxopGk+LuE5v3iY0StruifjQEt7SLbA46OnfgMO15sOTkbpJkol6RB+1l1TYPXh0sCddpvA==", + "version": "1.7.10", + "resolved": 
"https://registry.npmjs.org/@unrs/resolver-binding-win32-x64-msvc/-/resolver-binding-win32-x64-msvc-1.7.10.tgz", + "integrity": "sha512-DGAOo5asNvDsmFgwkb7xsgxNyN0If6XFYwDIC1QlRE7kEYWIMRChtWJyHDf30XmGovDNOs/37krxhnga/nm/4w==", "cpu": [ "x64" ], @@ -3070,9 +3185,9 @@ "license": "Python-2.0" }, "node_modules/aria-hidden": { - "version": "1.2.4", - "resolved": "https://registry.npmjs.org/aria-hidden/-/aria-hidden-1.2.4.tgz", - "integrity": "sha512-y+CcFFwelSXpLZk/7fMB2mUbGtX9lKycf1MWJ7CaTIERyitVlyQx6C+sxcROU2BAJ24OiZyK+8wj2i8AlBoS3A==", + "version": "1.2.6", + "resolved": "https://registry.npmjs.org/aria-hidden/-/aria-hidden-1.2.6.tgz", + "integrity": "sha512-ik3ZgC9dY/lYVVM++OISsaYDeg1tb0VtP5uL3ouh1koGOaUMDPpbFIei4JkFimWUFPn90sbMNMXQAIVOlnYKJA==", "license": "MIT", "dependencies": { "tslib": "^2.0.0" @@ -3109,18 +3224,20 @@ } }, "node_modules/array-includes": { - "version": "3.1.8", - "resolved": "https://registry.npmjs.org/array-includes/-/array-includes-3.1.8.tgz", - "integrity": "sha512-itaWrbYbqpGXkGhZPGUulwnhVf5Hpy1xiCFsGqyIGglbBxmG5vSjxQen3/WGOjPpNEv1RtBLKxbmVXm8HpJStQ==", + "version": "3.1.9", + "resolved": "https://registry.npmjs.org/array-includes/-/array-includes-3.1.9.tgz", + "integrity": "sha512-FmeCCAenzH0KH381SPT5FZmiA/TmpndpcaShhfgEN9eCVjnFBqq3l1xrI42y8+PPLI6hypzou4GXw00WHmPBLQ==", "dev": true, "license": "MIT", "dependencies": { - "call-bind": "^1.0.7", + "call-bind": "^1.0.8", + "call-bound": "^1.0.4", "define-properties": "^1.2.1", - "es-abstract": "^1.23.2", - "es-object-atoms": "^1.0.0", - "get-intrinsic": "^1.2.4", - "is-string": "^1.0.7" + "es-abstract": "^1.24.0", + "es-object-atoms": "^1.1.1", + "get-intrinsic": "^1.3.0", + "is-string": "^1.1.1", + "math-intrinsics": "^1.1.0" }, "engines": { "node": ">= 0.4" @@ -3431,9 +3548,9 @@ } }, "node_modules/caniuse-lite": { - "version": "1.0.30001715", - "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001715.tgz", - "integrity": 
"sha512-7ptkFGMm2OAOgvZpwgA4yjQ5SQbrNVGdRjzH0pBdy1Fasvcr+KAeECmbCAECzTuDuoX0FCY8KzUxjf9+9kfZEw==", + "version": "1.0.30001721", + "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001721.tgz", + "integrity": "sha512-cOuvmUVtKrtEaoKiO0rSc29jcjwMwX5tOHDy4MgVFEWiUXj4uBMJkwI8MDySkgXidpMiHUcviogAvFi4pA2hDQ==", "funding": [ { "type": "opencollective", @@ -3467,6 +3584,16 @@ "url": "https://github.com/chalk/chalk?sponsor=1" } }, + "node_modules/chownr": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/chownr/-/chownr-3.0.0.tgz", + "integrity": "sha512-+IxzY9BZOQd/XuYPRmrvEVjF/nqj5kgT4kEq7VofrDoM1MxoRjEWkrCC3EtLi59TVawxTAn+orJwFQcrqEN1+g==", + "dev": true, + "license": "BlueOak-1.0.0", + "engines": { + "node": ">=18" + } + }, "node_modules/class-variance-authority": { "version": "0.7.1", "resolved": "https://registry.npmjs.org/class-variance-authority/-/class-variance-authority-0.7.1.tgz", @@ -3819,9 +3946,9 @@ } }, "node_modules/debug": { - "version": "4.4.0", - "resolved": "https://registry.npmjs.org/debug/-/debug-4.4.0.tgz", - "integrity": "sha512-6WTZ/IxCY/T6BALoZHaE4ctp9xm+Z5kY/pzYaCHRFeyVhojxlrm+46y68HA6hr0TcwEssoxNiDEUJQjfPZ/RYA==", + "version": "4.4.1", + "resolved": "https://registry.npmjs.org/debug/-/debug-4.4.1.tgz", + "integrity": "sha512-KcKCqiftBJcZr++7ykoDIEwSa3XWowTfNPo92BYxjXiyYEVrUQh2aLyhxBCwww+heortUFxEJYcRzosstTEBYQ==", "dev": true, "license": "MIT", "dependencies": { @@ -3995,9 +4122,9 @@ } }, "node_modules/es-abstract": { - "version": "1.23.9", - "resolved": "https://registry.npmjs.org/es-abstract/-/es-abstract-1.23.9.tgz", - "integrity": "sha512-py07lI0wjxAC/DcfK1S6G7iANonniZwTISvdPzk9hzeH0IZIshbuuFxLIU96OyF89Yb9hiqWn8M/bY83KY5vzA==", + "version": "1.24.0", + "resolved": "https://registry.npmjs.org/es-abstract/-/es-abstract-1.24.0.tgz", + "integrity": "sha512-WSzPgsdLtTcQwm4CROfS5ju2Wa1QQcVeT37jFjYzdFz1r9ahadC8B8/a4qxJxM+09F18iumCdRmlr96ZYkQvEg==", "dev": true, "license": "MIT", "dependencies": { @@ -4005,18 
+4132,18 @@ "arraybuffer.prototype.slice": "^1.0.4", "available-typed-arrays": "^1.0.7", "call-bind": "^1.0.8", - "call-bound": "^1.0.3", + "call-bound": "^1.0.4", "data-view-buffer": "^1.0.2", "data-view-byte-length": "^1.0.2", "data-view-byte-offset": "^1.0.1", "es-define-property": "^1.0.1", "es-errors": "^1.3.0", - "es-object-atoms": "^1.0.0", + "es-object-atoms": "^1.1.1", "es-set-tostringtag": "^2.1.0", "es-to-primitive": "^1.3.0", "function.prototype.name": "^1.1.8", - "get-intrinsic": "^1.2.7", - "get-proto": "^1.0.0", + "get-intrinsic": "^1.3.0", + "get-proto": "^1.0.1", "get-symbol-description": "^1.1.0", "globalthis": "^1.0.4", "gopd": "^1.2.0", @@ -4028,21 +4155,24 @@ "is-array-buffer": "^3.0.5", "is-callable": "^1.2.7", "is-data-view": "^1.0.2", + "is-negative-zero": "^2.0.3", "is-regex": "^1.2.1", + "is-set": "^2.0.3", "is-shared-array-buffer": "^1.0.4", "is-string": "^1.1.1", "is-typed-array": "^1.1.15", - "is-weakref": "^1.1.0", + "is-weakref": "^1.1.1", "math-intrinsics": "^1.1.0", - "object-inspect": "^1.13.3", + "object-inspect": "^1.13.4", "object-keys": "^1.1.1", "object.assign": "^4.1.7", "own-keys": "^1.0.1", - "regexp.prototype.flags": "^1.5.3", + "regexp.prototype.flags": "^1.5.4", "safe-array-concat": "^1.1.3", "safe-push-apply": "^1.0.0", "safe-regex-test": "^1.1.0", "set-proto": "^1.0.0", + "stop-iteration-iterator": "^1.1.0", "string.prototype.trim": "^1.2.10", "string.prototype.trimend": "^1.0.9", "string.prototype.trimstart": "^1.0.8", @@ -4051,7 +4181,7 @@ "typed-array-byte-offset": "^1.0.4", "typed-array-length": "^1.0.7", "unbox-primitive": "^1.1.0", - "which-typed-array": "^1.1.18" + "which-typed-array": "^1.1.19" }, "engines": { "node": ">= 0.4" @@ -4268,6 +4398,41 @@ } } }, + "node_modules/eslint-config-next/node_modules/eslint-import-resolver-typescript": { + "version": "3.10.1", + "resolved": "https://registry.npmjs.org/eslint-import-resolver-typescript/-/eslint-import-resolver-typescript-3.10.1.tgz", + "integrity": 
"sha512-A1rHYb06zjMGAxdLSkN2fXPBwuSaQ0iO5M/hdyS0Ajj1VBaRp0sPD3dn1FhME3c/JluGFbwSxyCfqdSbtQLAHQ==", + "dev": true, + "license": "ISC", + "dependencies": { + "@nolyfill/is-core-module": "1.0.39", + "debug": "^4.4.0", + "get-tsconfig": "^4.10.0", + "is-bun-module": "^2.0.0", + "stable-hash": "^0.0.5", + "tinyglobby": "^0.2.13", + "unrs-resolver": "^1.6.2" + }, + "engines": { + "node": "^14.18.0 || >=16.0.0" + }, + "funding": { + "url": "https://opencollective.com/eslint-import-resolver-typescript" + }, + "peerDependencies": { + "eslint": "*", + "eslint-plugin-import": "*", + "eslint-plugin-import-x": "*" + }, + "peerDependenciesMeta": { + "eslint-plugin-import": { + "optional": true + }, + "eslint-plugin-import-x": { + "optional": true + } + } + }, "node_modules/eslint-config-prettier": { "version": "8.10.0", "resolved": "https://registry.npmjs.org/eslint-config-prettier/-/eslint-config-prettier-8.10.0.tgz", @@ -4303,41 +4468,6 @@ "ms": "^2.1.1" } }, - "node_modules/eslint-import-resolver-typescript": { - "version": "3.10.1", - "resolved": "https://registry.npmjs.org/eslint-import-resolver-typescript/-/eslint-import-resolver-typescript-3.10.1.tgz", - "integrity": "sha512-A1rHYb06zjMGAxdLSkN2fXPBwuSaQ0iO5M/hdyS0Ajj1VBaRp0sPD3dn1FhME3c/JluGFbwSxyCfqdSbtQLAHQ==", - "dev": true, - "license": "ISC", - "dependencies": { - "@nolyfill/is-core-module": "1.0.39", - "debug": "^4.4.0", - "get-tsconfig": "^4.10.0", - "is-bun-module": "^2.0.0", - "stable-hash": "^0.0.5", - "tinyglobby": "^0.2.13", - "unrs-resolver": "^1.6.2" - }, - "engines": { - "node": "^14.18.0 || >=16.0.0" - }, - "funding": { - "url": "https://opencollective.com/eslint-import-resolver-typescript" - }, - "peerDependencies": { - "eslint": "*", - "eslint-plugin-import": "*", - "eslint-plugin-import-x": "*" - }, - "peerDependenciesMeta": { - "eslint-plugin-import": { - "optional": true - }, - "eslint-plugin-import-x": { - "optional": true - } - } - }, "node_modules/eslint-module-utils": { "version": "2.12.0", 
"resolved": "https://registry.npmjs.org/eslint-module-utils/-/eslint-module-utils-2.12.0.tgz", @@ -4924,14 +5054,15 @@ } }, "node_modules/form-data": { - "version": "4.0.2", - "resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.2.tgz", - "integrity": "sha512-hGfm/slu0ZabnNt4oaRZ6uREyfCj6P4fT/n6A1rGV+Z0VdGXjfOhVUpkn6qVQONHGIFwmveGXyDs75+nr6FM8w==", + "version": "4.0.3", + "resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.3.tgz", + "integrity": "sha512-qsITQPfmvMOSAdeyZ+12I1c+CKSstAFAwu+97zrnWAbIr5u8wfsExUzCesVLC8NgHuRUqNN4Zy6UPWUTRGslcA==", "license": "MIT", "dependencies": { "asynckit": "^0.4.0", "combined-stream": "^1.0.8", "es-set-tostringtag": "^2.1.0", + "hasown": "^2.0.2", "mime-types": "^2.1.12" }, "engines": { @@ -5103,9 +5234,9 @@ } }, "node_modules/get-tsconfig": { - "version": "4.10.0", - "resolved": "https://registry.npmjs.org/get-tsconfig/-/get-tsconfig-4.10.0.tgz", - "integrity": "sha512-kGzZ3LWWQcGIAmg6iWvXn0ei6WDtV26wzHRMwDSzmAbcXrTEXxHy6IehI6/4eT6VRKyMP1eF1VqwrVUmE/LR7A==", + "version": "4.10.1", + "resolved": "https://registry.npmjs.org/get-tsconfig/-/get-tsconfig-4.10.1.tgz", + "integrity": "sha512-auHyJ4AgMz7vgS8Hp3N6HXSmlMdUyhSUrfBF16w153rxtLIEOE+HGqaBppczZvnHLqQJfiHotCYpNhl0lUROFQ==", "dev": true, "license": "MIT", "dependencies": { @@ -5682,6 +5813,19 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/is-negative-zero": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/is-negative-zero/-/is-negative-zero-2.0.3.tgz", + "integrity": "sha512-5KoIu2Ngpyek75jXodFvnafB6DJgr3u8uuK0LEZJjrU19DrMD3EVERaR8sjz8CCGgpZvxPl9SuE1GMVPFHx1mw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, "node_modules/is-number": { "version": "7.0.0", "resolved": "https://registry.npmjs.org/is-number/-/is-number-7.0.0.tgz", @@ -6058,9 +6202,9 @@ } }, "node_modules/lightningcss": { - "version": "1.29.2", - "resolved": 
"https://registry.npmjs.org/lightningcss/-/lightningcss-1.29.2.tgz", - "integrity": "sha512-6b6gd/RUXKaw5keVdSEtqFVdzWnU5jMxTUjA2bVcMNPLwSQ08Sv/UodBVtETLCn7k4S1Ibxwh7k68IwLZPgKaA==", + "version": "1.30.1", + "resolved": "https://registry.npmjs.org/lightningcss/-/lightningcss-1.30.1.tgz", + "integrity": "sha512-xi6IyHML+c9+Q3W0S4fCQJOym42pyurFiJUHEcEyHS0CeKzia4yZDEsLlqOFykxOdHpNy0NmvVO31vcSqAxJCg==", "dev": true, "license": "MPL-2.0", "dependencies": { @@ -6074,22 +6218,22 @@ "url": "https://opencollective.com/parcel" }, "optionalDependencies": { - "lightningcss-darwin-arm64": "1.29.2", - "lightningcss-darwin-x64": "1.29.2", - "lightningcss-freebsd-x64": "1.29.2", - "lightningcss-linux-arm-gnueabihf": "1.29.2", - "lightningcss-linux-arm64-gnu": "1.29.2", - "lightningcss-linux-arm64-musl": "1.29.2", - "lightningcss-linux-x64-gnu": "1.29.2", - "lightningcss-linux-x64-musl": "1.29.2", - "lightningcss-win32-arm64-msvc": "1.29.2", - "lightningcss-win32-x64-msvc": "1.29.2" + "lightningcss-darwin-arm64": "1.30.1", + "lightningcss-darwin-x64": "1.30.1", + "lightningcss-freebsd-x64": "1.30.1", + "lightningcss-linux-arm-gnueabihf": "1.30.1", + "lightningcss-linux-arm64-gnu": "1.30.1", + "lightningcss-linux-arm64-musl": "1.30.1", + "lightningcss-linux-x64-gnu": "1.30.1", + "lightningcss-linux-x64-musl": "1.30.1", + "lightningcss-win32-arm64-msvc": "1.30.1", + "lightningcss-win32-x64-msvc": "1.30.1" } }, "node_modules/lightningcss-darwin-arm64": { - "version": "1.29.2", - "resolved": "https://registry.npmjs.org/lightningcss-darwin-arm64/-/lightningcss-darwin-arm64-1.29.2.tgz", - "integrity": "sha512-cK/eMabSViKn/PG8U/a7aCorpeKLMlK0bQeNHmdb7qUnBkNPnL+oV5DjJUo0kqWsJUapZsM4jCfYItbqBDvlcA==", + "version": "1.30.1", + "resolved": "https://registry.npmjs.org/lightningcss-darwin-arm64/-/lightningcss-darwin-arm64-1.30.1.tgz", + "integrity": "sha512-c8JK7hyE65X1MHMN+Viq9n11RRC7hgin3HhYKhrMyaXflk5GVplZ60IxyoVtzILeKr+xAJwg6zK6sjTBJ0FKYQ==", "cpu": [ "arm64" ], @@ -6108,9 +6252,9 @@ } }, 
"node_modules/lightningcss-darwin-x64": { - "version": "1.29.2", - "resolved": "https://registry.npmjs.org/lightningcss-darwin-x64/-/lightningcss-darwin-x64-1.29.2.tgz", - "integrity": "sha512-j5qYxamyQw4kDXX5hnnCKMf3mLlHvG44f24Qyi2965/Ycz829MYqjrVg2H8BidybHBp9kom4D7DR5VqCKDXS0w==", + "version": "1.30.1", + "resolved": "https://registry.npmjs.org/lightningcss-darwin-x64/-/lightningcss-darwin-x64-1.30.1.tgz", + "integrity": "sha512-k1EvjakfumAQoTfcXUcHQZhSpLlkAuEkdMBsI/ivWw9hL+7FtilQc0Cy3hrx0AAQrVtQAbMI7YjCgYgvn37PzA==", "cpu": [ "x64" ], @@ -6129,9 +6273,9 @@ } }, "node_modules/lightningcss-freebsd-x64": { - "version": "1.29.2", - "resolved": "https://registry.npmjs.org/lightningcss-freebsd-x64/-/lightningcss-freebsd-x64-1.29.2.tgz", - "integrity": "sha512-wDk7M2tM78Ii8ek9YjnY8MjV5f5JN2qNVO+/0BAGZRvXKtQrBC4/cn4ssQIpKIPP44YXw6gFdpUF+Ps+RGsCwg==", + "version": "1.30.1", + "resolved": "https://registry.npmjs.org/lightningcss-freebsd-x64/-/lightningcss-freebsd-x64-1.30.1.tgz", + "integrity": "sha512-kmW6UGCGg2PcyUE59K5r0kWfKPAVy4SltVeut+umLCFoJ53RdCUWxcRDzO1eTaxf/7Q2H7LTquFHPL5R+Gjyig==", "cpu": [ "x64" ], @@ -6150,9 +6294,9 @@ } }, "node_modules/lightningcss-linux-arm-gnueabihf": { - "version": "1.29.2", - "resolved": "https://registry.npmjs.org/lightningcss-linux-arm-gnueabihf/-/lightningcss-linux-arm-gnueabihf-1.29.2.tgz", - "integrity": "sha512-IRUrOrAF2Z+KExdExe3Rz7NSTuuJ2HvCGlMKoquK5pjvo2JY4Rybr+NrKnq0U0hZnx5AnGsuFHjGnNT14w26sg==", + "version": "1.30.1", + "resolved": "https://registry.npmjs.org/lightningcss-linux-arm-gnueabihf/-/lightningcss-linux-arm-gnueabihf-1.30.1.tgz", + "integrity": "sha512-MjxUShl1v8pit+6D/zSPq9S9dQ2NPFSQwGvxBCYaBYLPlCWuPh9/t1MRS8iUaR8i+a6w7aps+B4N0S1TYP/R+Q==", "cpu": [ "arm" ], @@ -6171,9 +6315,9 @@ } }, "node_modules/lightningcss-linux-arm64-gnu": { - "version": "1.29.2", - "resolved": "https://registry.npmjs.org/lightningcss-linux-arm64-gnu/-/lightningcss-linux-arm64-gnu-1.29.2.tgz", - "integrity": 
"sha512-KKCpOlmhdjvUTX/mBuaKemp0oeDIBBLFiU5Fnqxh1/DZ4JPZi4evEH7TKoSBFOSOV3J7iEmmBaw/8dpiUvRKlQ==", + "version": "1.30.1", + "resolved": "https://registry.npmjs.org/lightningcss-linux-arm64-gnu/-/lightningcss-linux-arm64-gnu-1.30.1.tgz", + "integrity": "sha512-gB72maP8rmrKsnKYy8XUuXi/4OctJiuQjcuqWNlJQ6jZiWqtPvqFziskH3hnajfvKB27ynbVCucKSm2rkQp4Bw==", "cpu": [ "arm64" ], @@ -6192,9 +6336,9 @@ } }, "node_modules/lightningcss-linux-arm64-musl": { - "version": "1.29.2", - "resolved": "https://registry.npmjs.org/lightningcss-linux-arm64-musl/-/lightningcss-linux-arm64-musl-1.29.2.tgz", - "integrity": "sha512-Q64eM1bPlOOUgxFmoPUefqzY1yV3ctFPE6d/Vt7WzLW4rKTv7MyYNky+FWxRpLkNASTnKQUaiMJ87zNODIrrKQ==", + "version": "1.30.1", + "resolved": "https://registry.npmjs.org/lightningcss-linux-arm64-musl/-/lightningcss-linux-arm64-musl-1.30.1.tgz", + "integrity": "sha512-jmUQVx4331m6LIX+0wUhBbmMX7TCfjF5FoOH6SD1CttzuYlGNVpA7QnrmLxrsub43ClTINfGSYyHe2HWeLl5CQ==", "cpu": [ "arm64" ], @@ -6213,9 +6357,9 @@ } }, "node_modules/lightningcss-linux-x64-gnu": { - "version": "1.29.2", - "resolved": "https://registry.npmjs.org/lightningcss-linux-x64-gnu/-/lightningcss-linux-x64-gnu-1.29.2.tgz", - "integrity": "sha512-0v6idDCPG6epLXtBH/RPkHvYx74CVziHo6TMYga8O2EiQApnUPZsbR9nFNrg2cgBzk1AYqEd95TlrsL7nYABQg==", + "version": "1.30.1", + "resolved": "https://registry.npmjs.org/lightningcss-linux-x64-gnu/-/lightningcss-linux-x64-gnu-1.30.1.tgz", + "integrity": "sha512-piWx3z4wN8J8z3+O5kO74+yr6ze/dKmPnI7vLqfSqI8bccaTGY5xiSGVIJBDd5K5BHlvVLpUB3S2YCfelyJ1bw==", "cpu": [ "x64" ], @@ -6234,9 +6378,9 @@ } }, "node_modules/lightningcss-linux-x64-musl": { - "version": "1.29.2", - "resolved": "https://registry.npmjs.org/lightningcss-linux-x64-musl/-/lightningcss-linux-x64-musl-1.29.2.tgz", - "integrity": "sha512-rMpz2yawkgGT8RULc5S4WiZopVMOFWjiItBT7aSfDX4NQav6M44rhn5hjtkKzB+wMTRlLLqxkeYEtQ3dd9696w==", + "version": "1.30.1", + "resolved": 
"https://registry.npmjs.org/lightningcss-linux-x64-musl/-/lightningcss-linux-x64-musl-1.30.1.tgz", + "integrity": "sha512-rRomAK7eIkL+tHY0YPxbc5Dra2gXlI63HL+v1Pdi1a3sC+tJTcFrHX+E86sulgAXeI7rSzDYhPSeHHjqFhqfeQ==", "cpu": [ "x64" ], @@ -6255,9 +6399,9 @@ } }, "node_modules/lightningcss-win32-arm64-msvc": { - "version": "1.29.2", - "resolved": "https://registry.npmjs.org/lightningcss-win32-arm64-msvc/-/lightningcss-win32-arm64-msvc-1.29.2.tgz", - "integrity": "sha512-nL7zRW6evGQqYVu/bKGK+zShyz8OVzsCotFgc7judbt6wnB2KbiKKJwBE4SGoDBQ1O94RjW4asrCjQL4i8Fhbw==", + "version": "1.30.1", + "resolved": "https://registry.npmjs.org/lightningcss-win32-arm64-msvc/-/lightningcss-win32-arm64-msvc-1.30.1.tgz", + "integrity": "sha512-mSL4rqPi4iXq5YVqzSsJgMVFENoa4nGTT/GjO2c0Yl9OuQfPsIfncvLrEW6RbbB24WtZ3xP/2CCmI3tNkNV4oA==", "cpu": [ "arm64" ], @@ -6276,9 +6420,9 @@ } }, "node_modules/lightningcss-win32-x64-msvc": { - "version": "1.29.2", - "resolved": "https://registry.npmjs.org/lightningcss-win32-x64-msvc/-/lightningcss-win32-x64-msvc-1.29.2.tgz", - "integrity": "sha512-EdIUW3B2vLuHmv7urfzMI/h2fmlnOQBk1xlsDxkN1tCWKjNFjfLhGxYk8C8mzpSfr+A6jFFIi8fU6LbQGsRWjA==", + "version": "1.30.1", + "resolved": "https://registry.npmjs.org/lightningcss-win32-x64-msvc/-/lightningcss-win32-x64-msvc-1.30.1.tgz", + "integrity": "sha512-PVqXh48wh4T53F/1CCu8PIPCxLzWyCnn/9T5W1Jpmdy5h9Cwd+0YQS6/LwhHXSafuc61/xg9Lv5OrCby6a++jg==", "cpu": [ "x64" ], @@ -6518,6 +6662,16 @@ "react": "^16.5.1 || ^17.0.0 || ^18.0.0 || ^19.0.0-rc" } }, + "node_modules/magic-string": { + "version": "0.30.17", + "resolved": "https://registry.npmjs.org/magic-string/-/magic-string-0.30.17.tgz", + "integrity": "sha512-sNPKHvyjVf7gyjwS4xGTaW/mCnF8wnjtifKBEhxfZ7E/S8tQ0rssrwGNn6q8JH/ohItJfSQp9mBtQYuTlH5QnA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@jridgewell/sourcemap-codec": "^1.5.0" + } + }, "node_modules/math-intrinsics": { "version": "1.1.0", "resolved": 
"https://registry.npmjs.org/math-intrinsics/-/math-intrinsics-1.1.0.tgz", @@ -6658,6 +6812,35 @@ "node": ">=16 || 14 >=14.17" } }, + "node_modules/minizlib": { + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/minizlib/-/minizlib-3.0.2.tgz", + "integrity": "sha512-oG62iEk+CYt5Xj2YqI5Xi9xWUeZhDI8jjQmC5oThVH5JGCTgIjr7ciJDzC7MBzYd//WvR1OTmP5Q38Q8ShQtVA==", + "dev": true, + "license": "MIT", + "dependencies": { + "minipass": "^7.1.2" + }, + "engines": { + "node": ">= 18" + } + }, + "node_modules/mkdirp": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/mkdirp/-/mkdirp-3.0.1.tgz", + "integrity": "sha512-+NsyUUAZDmo6YVHzL/stxSu3t9YS1iljliy3BSDrXJ/dkn1KYdmtZODGGjLcc9XLgVVpH4KshHB8XmZgMhaBXg==", + "dev": true, + "license": "MIT", + "bin": { + "mkdirp": "dist/cjs/src/bin.js" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, "node_modules/motion-dom": { "version": "11.18.1", "resolved": "https://registry.npmjs.org/motion-dom/-/motion-dom-11.18.1.tgz", @@ -6699,9 +6882,9 @@ } }, "node_modules/napi-postinstall": { - "version": "0.2.2", - "resolved": "https://registry.npmjs.org/napi-postinstall/-/napi-postinstall-0.2.2.tgz", - "integrity": "sha512-Wy1VI/hpKHwy1MsnFxHCJxqFwmmxD0RA/EKPL7e6mfbsY01phM2SZyJnRdU0bLvhu0Quby1DCcAZti3ghdl4/A==", + "version": "0.2.4", + "resolved": "https://registry.npmjs.org/napi-postinstall/-/napi-postinstall-0.2.4.tgz", + "integrity": "sha512-ZEzHJwBhZ8qQSbknHqYcdtQVr8zUgGyM/q6h6qAyhtyVMNrSgDhrC4disf03dYW0e+czXyLnZINnCTEkWy0eJg==", "dev": true, "license": "MIT", "bin": { @@ -6729,12 +6912,12 @@ "license": "MIT" }, "node_modules/next": { - "version": "14.2.28", - "resolved": "https://registry.npmjs.org/next/-/next-14.2.28.tgz", - "integrity": "sha512-QLEIP/kYXynIxtcKB6vNjtWLVs3Y4Sb+EClTC/CSVzdLD1gIuItccpu/n1lhmduffI32iPGEK2cLLxxt28qgYA==", + "version": "14.2.29", + "resolved": "https://registry.npmjs.org/next/-/next-14.2.29.tgz", + "integrity": 
"sha512-s98mCOMOWLGGpGOfgKSnleXLuegvvH415qtRZXpSp00HeEgdmrxmwL9cgKU+h4XrhB16zEI5d/7BnkS3ATInsA==", "license": "MIT", "dependencies": { - "@next/env": "14.2.28", + "@next/env": "14.2.29", "@swc/helpers": "0.5.5", "busboy": "1.6.0", "caniuse-lite": "^1.0.30001579", @@ -6749,15 +6932,15 @@ "node": ">=18.17.0" }, "optionalDependencies": { - "@next/swc-darwin-arm64": "14.2.28", - "@next/swc-darwin-x64": "14.2.28", - "@next/swc-linux-arm64-gnu": "14.2.28", - "@next/swc-linux-arm64-musl": "14.2.28", - "@next/swc-linux-x64-gnu": "14.2.28", - "@next/swc-linux-x64-musl": "14.2.28", - "@next/swc-win32-arm64-msvc": "14.2.28", - "@next/swc-win32-ia32-msvc": "14.2.28", - "@next/swc-win32-x64-msvc": "14.2.28" + "@next/swc-darwin-arm64": "14.2.29", + "@next/swc-darwin-x64": "14.2.29", + "@next/swc-linux-arm64-gnu": "14.2.29", + "@next/swc-linux-arm64-musl": "14.2.29", + "@next/swc-linux-x64-gnu": "14.2.29", + "@next/swc-linux-x64-musl": "14.2.29", + "@next/swc-win32-arm64-msvc": "14.2.29", + "@next/swc-win32-ia32-msvc": "14.2.29", + "@next/swc-win32-x64-msvc": "14.2.29" }, "peerDependencies": { "@opentelemetry/api": "^1.1.0", @@ -7195,9 +7378,9 @@ } }, "node_modules/postcss": { - "version": "8.5.3", - "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.3.tgz", - "integrity": "sha512-dle9A3yYxlBSrt8Fu+IpjGT8SY8hN0mlaA6GY8t0P5PjIOZemULz/E2Bnm/2dcUOena75OTNkHI76uZBNUUq3A==", + "version": "8.5.4", + "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.4.tgz", + "integrity": "sha512-QSa9EBe+uwlGTFmHsPKokv3B/oEMQZxfqW0QqNCyhpa6mB1afzulwn8hihglqAb2pOw+BJgNlmXQ8la2VeHB7w==", "dev": true, "funding": [ { @@ -7215,7 +7398,7 @@ ], "license": "MIT", "dependencies": { - "nanoid": "^3.3.8", + "nanoid": "^3.3.11", "picocolors": "^1.1.1", "source-map-js": "^1.2.1" }, @@ -7311,57 +7494,58 @@ "license": "MIT" }, "node_modules/radix-ui": { - "version": "1.3.4", - "resolved": "https://registry.npmjs.org/radix-ui/-/radix-ui-1.3.4.tgz", - "integrity": 
"sha512-uHJD4yRGjxbEWhkVU+w9d8d+X6HUlmbesHGsE9tRWKX62FqDD3Z3hfEtVS9W+DpZAPvKSCLfz03O7un8xZT3pg==", + "version": "1.4.2", + "resolved": "https://registry.npmjs.org/radix-ui/-/radix-ui-1.4.2.tgz", + "integrity": "sha512-fT/3YFPJzf2WUpqDoQi005GS8EpCi+53VhcLaHUj5fwkPYiZAjk1mSxFvbMA8Uq71L03n+WysuYC+mlKkXxt/Q==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", - "@radix-ui/react-accessible-icon": "1.1.4", - "@radix-ui/react-accordion": "1.2.8", - "@radix-ui/react-alert-dialog": "1.1.11", - "@radix-ui/react-arrow": "1.1.4", - "@radix-ui/react-aspect-ratio": "1.1.4", - "@radix-ui/react-avatar": "1.1.7", - "@radix-ui/react-checkbox": "1.2.3", - "@radix-ui/react-collapsible": "1.1.8", - "@radix-ui/react-collection": "1.1.4", + "@radix-ui/react-accessible-icon": "1.1.7", + "@radix-ui/react-accordion": "1.2.11", + "@radix-ui/react-alert-dialog": "1.1.14", + "@radix-ui/react-arrow": "1.1.7", + "@radix-ui/react-aspect-ratio": "1.1.7", + "@radix-ui/react-avatar": "1.1.10", + "@radix-ui/react-checkbox": "1.3.2", + "@radix-ui/react-collapsible": "1.1.11", + "@radix-ui/react-collection": "1.1.7", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", - "@radix-ui/react-context-menu": "2.2.12", - "@radix-ui/react-dialog": "1.1.11", + "@radix-ui/react-context-menu": "2.2.15", + "@radix-ui/react-dialog": "1.1.14", "@radix-ui/react-direction": "1.1.1", - "@radix-ui/react-dismissable-layer": "1.1.7", - "@radix-ui/react-dropdown-menu": "2.1.12", + "@radix-ui/react-dismissable-layer": "1.1.10", + "@radix-ui/react-dropdown-menu": "2.1.15", "@radix-ui/react-focus-guards": "1.1.2", - "@radix-ui/react-focus-scope": "1.1.4", - "@radix-ui/react-form": "0.1.4", - "@radix-ui/react-hover-card": "1.1.11", - "@radix-ui/react-label": "2.1.4", - "@radix-ui/react-menu": "2.1.12", - "@radix-ui/react-menubar": "1.1.12", - "@radix-ui/react-navigation-menu": "1.2.10", - "@radix-ui/react-one-time-password-field": "0.1.4", - "@radix-ui/react-popover": "1.1.11", - 
"@radix-ui/react-popper": "1.2.4", - "@radix-ui/react-portal": "1.1.6", + "@radix-ui/react-focus-scope": "1.1.7", + "@radix-ui/react-form": "0.1.7", + "@radix-ui/react-hover-card": "1.1.14", + "@radix-ui/react-label": "2.1.7", + "@radix-ui/react-menu": "2.1.15", + "@radix-ui/react-menubar": "1.1.15", + "@radix-ui/react-navigation-menu": "1.2.13", + "@radix-ui/react-one-time-password-field": "0.1.7", + "@radix-ui/react-password-toggle-field": "0.1.2", + "@radix-ui/react-popover": "1.1.14", + "@radix-ui/react-popper": "1.2.7", + "@radix-ui/react-portal": "1.1.9", "@radix-ui/react-presence": "1.1.4", - "@radix-ui/react-primitive": "2.1.0", - "@radix-ui/react-progress": "1.1.4", - "@radix-ui/react-radio-group": "1.3.4", - "@radix-ui/react-roving-focus": "1.1.7", - "@radix-ui/react-scroll-area": "1.2.6", - "@radix-ui/react-select": "2.2.2", - "@radix-ui/react-separator": "1.1.4", - "@radix-ui/react-slider": "1.3.2", - "@radix-ui/react-slot": "1.2.0", - "@radix-ui/react-switch": "1.2.2", - "@radix-ui/react-tabs": "1.1.9", - "@radix-ui/react-toast": "1.2.11", - "@radix-ui/react-toggle": "1.1.6", - "@radix-ui/react-toggle-group": "1.1.7", - "@radix-ui/react-toolbar": "1.1.7", - "@radix-ui/react-tooltip": "1.2.4", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-progress": "1.1.7", + "@radix-ui/react-radio-group": "1.3.7", + "@radix-ui/react-roving-focus": "1.1.10", + "@radix-ui/react-scroll-area": "1.2.9", + "@radix-ui/react-select": "2.2.5", + "@radix-ui/react-separator": "1.1.7", + "@radix-ui/react-slider": "1.3.5", + "@radix-ui/react-slot": "1.2.3", + "@radix-ui/react-switch": "1.2.5", + "@radix-ui/react-tabs": "1.1.12", + "@radix-ui/react-toast": "1.2.14", + "@radix-ui/react-toggle": "1.1.9", + "@radix-ui/react-toggle-group": "1.1.10", + "@radix-ui/react-toolbar": "1.1.10", + "@radix-ui/react-tooltip": "1.2.7", "@radix-ui/react-use-callback-ref": "1.1.1", "@radix-ui/react-use-controllable-state": "1.2.2", "@radix-ui/react-use-effect-event": "0.0.2", @@ 
-7369,7 +7553,7 @@ "@radix-ui/react-use-is-hydrated": "0.1.0", "@radix-ui/react-use-layout-effect": "1.1.1", "@radix-ui/react-use-size": "1.1.1", - "@radix-ui/react-visually-hidden": "1.2.0" + "@radix-ui/react-visually-hidden": "1.2.3" }, "peerDependencies": { "@types/react": "*", @@ -7412,9 +7596,9 @@ } }, "node_modules/react-hook-form": { - "version": "7.56.1", - "resolved": "https://registry.npmjs.org/react-hook-form/-/react-hook-form-7.56.1.tgz", - "integrity": "sha512-qWAVokhSpshhcEuQDSANHx3jiAEFzu2HAaaQIzi/r9FNPm1ioAvuJSD4EuZzWd7Al7nTRKcKPnBKO7sRn+zavQ==", + "version": "7.57.0", + "resolved": "https://registry.npmjs.org/react-hook-form/-/react-hook-form-7.57.0.tgz", + "integrity": "sha512-RbEks3+cbvTP84l/VXGUZ+JMrKOS8ykQCRYdm5aYsxnDquL0vspsyNhGRO7pcH6hsZqWlPOjLye7rJqdtdAmlg==", "license": "MIT", "engines": { "node": ">=18.0.0" @@ -7434,9 +7618,9 @@ "license": "MIT" }, "node_modules/react-remove-scroll": { - "version": "2.6.3", - "resolved": "https://registry.npmjs.org/react-remove-scroll/-/react-remove-scroll-2.6.3.tgz", - "integrity": "sha512-pnAi91oOk8g8ABQKGF5/M9qxmmOPxaAnopyTHYfqYEwJhyFrbbBtHuSgtKEoH0jpcxx5o3hXqH1mNd9/Oi+8iQ==", + "version": "2.7.1", + "resolved": "https://registry.npmjs.org/react-remove-scroll/-/react-remove-scroll-2.7.1.tgz", + "integrity": "sha512-HpMh8+oahmIdOuS5aFKKY6Pyog+FNaZV/XyJOq7b4YFwsFHe5yYfdbIalI4k3vU2nSDql7YskmUseHsRrJqIPA==", "license": "MIT", "dependencies": { "react-remove-scroll-bar": "^2.3.7", @@ -7600,12 +7784,6 @@ "url": "https://github.com/sponsors/ljharb" } }, - "node_modules/regenerator-runtime": { - "version": "0.14.1", - "resolved": "https://registry.npmjs.org/regenerator-runtime/-/regenerator-runtime-0.14.1.tgz", - "integrity": "sha512-dYnhHh0nJoMfnkZs6GmmhFknAGRrLznOu5nc9ML+EJxGvrx6H7teuevqVqCuPcPK//3eDrrjQhehXVx9cnkGdw==", - "license": "MIT" - }, "node_modules/regexp.prototype.flags": { "version": "1.5.4", "resolved": "https://registry.npmjs.org/regexp.prototype.flags/-/regexp.prototype.flags-1.5.4.tgz", @@ 
-7870,9 +8048,9 @@ "license": "MIT" }, "node_modules/semver": { - "version": "7.7.1", - "resolved": "https://registry.npmjs.org/semver/-/semver-7.7.1.tgz", - "integrity": "sha512-hlq8tAfn0m/61p4BVRcPzIGr6LKiMwo4VM6dGi6pt4qcRkmNzTcWq6eCEjEh+qXjkMDvPlOFFSGwQjoEa6gyMA==", + "version": "7.7.2", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.7.2.tgz", + "integrity": "sha512-RF0Fw+rO5AMf9MAyaRXI4AV0Ulj5lMHqVxxdSgiVbixSCXoEmmX/jk0CuJw4+3SqroYO9VoUh+HcuJivvtJemA==", "dev": true, "license": "ISC", "bin": { @@ -8099,6 +8277,20 @@ "dev": true, "license": "MIT" }, + "node_modules/stop-iteration-iterator": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/stop-iteration-iterator/-/stop-iteration-iterator-1.1.0.tgz", + "integrity": "sha512-eLoXW/DHyl62zxY4SCaIgnRhuMr6ri4juEYARS8E6sCEqzKpOiE521Ucofdx+KnDZl5xmvGYaaKCk5FEOxJCoQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0", + "internal-slot": "^1.1.0" + }, + "engines": { + "node": ">= 0.4" + } + }, "node_modules/streamsearch": { "version": "1.1.0", "resolved": "https://registry.npmjs.org/streamsearch/-/streamsearch-1.1.0.tgz", @@ -8433,9 +8625,9 @@ } }, "node_modules/tailwindcss": { - "version": "4.1.4", - "resolved": "https://registry.npmjs.org/tailwindcss/-/tailwindcss-4.1.4.tgz", - "integrity": "sha512-1ZIUqtPITFbv/DxRmDr5/agPqJwF69d24m9qmM1939TJehgY539CtzeZRjbLt5G6fSy/7YqqYsfvoTEw9xUI2A==", + "version": "4.1.8", + "resolved": "https://registry.npmjs.org/tailwindcss/-/tailwindcss-4.1.8.tgz", + "integrity": "sha512-kjeW8gjdxasbmFKpVGrGd5T4i40mV5J2Rasw48QARfYeQ8YS9x02ON9SFWax3Qf616rt4Cp3nVNIj6Hd1mP3og==", "dev": true, "license": "MIT" }, @@ -8450,15 +8642,33 @@ } }, "node_modules/tapable": { - "version": "2.2.1", - "resolved": "https://registry.npmjs.org/tapable/-/tapable-2.2.1.tgz", - "integrity": "sha512-GNzQvQTOIP6RyTfE2Qxb8ZVlNmw0n88vp1szwWRimP02mnTsx3Wtn5qRdqY9w2XduFNUgvOwhNnQsjwCp+kqaQ==", + "version": "2.2.2", + "resolved": 
"https://registry.npmjs.org/tapable/-/tapable-2.2.2.tgz", + "integrity": "sha512-Re10+NauLTMCudc7T5WLFLAwDhQ0JWdrMK+9B2M8zR5hRExKmsRDCBA7/aV/pNJFltmBFO5BAMlQFi/vq3nKOg==", "dev": true, "license": "MIT", "engines": { "node": ">=6" } }, + "node_modules/tar": { + "version": "7.4.3", + "resolved": "https://registry.npmjs.org/tar/-/tar-7.4.3.tgz", + "integrity": "sha512-5S7Va8hKfV7W5U6g3aYxXmlPoZVAwUMy9AOKyF2fVuZa2UD3qZjg578OrLRt8PcNN1PleVaL/5/yYATNL0ICUw==", + "dev": true, + "license": "ISC", + "dependencies": { + "@isaacs/fs-minipass": "^4.0.0", + "chownr": "^3.0.0", + "minipass": "^7.1.2", + "minizlib": "^3.0.1", + "mkdirp": "^3.0.1", + "yallist": "^5.0.0" + }, + "engines": { + "node": ">=18" + } + }, "node_modules/text-table": { "version": "0.2.0", "resolved": "https://registry.npmjs.org/text-table/-/text-table-0.2.0.tgz", @@ -8479,9 +8689,9 @@ "license": "MIT" }, "node_modules/tinyglobby": { - "version": "0.2.13", - "resolved": "https://registry.npmjs.org/tinyglobby/-/tinyglobby-0.2.13.tgz", - "integrity": "sha512-mEwzpUgrLySlveBwEVDMKk5B57bhLPYovRfPAXD5gA/98Opn0rCDj3GtLwFvCvH5RK9uPCExUROW5NjDwvqkxw==", + "version": "0.2.14", + "resolved": "https://registry.npmjs.org/tinyglobby/-/tinyglobby-0.2.14.tgz", + "integrity": "sha512-tX5e7OM1HnYr2+a2C/4V0htOcSQcoSTH9KgJnVvNm5zm/cyEWKJ7j7YutsH9CxMdtOkkLFy2AHrMci9IM8IPZQ==", "dev": true, "license": "MIT", "dependencies": { @@ -8496,9 +8706,9 @@ } }, "node_modules/tinyglobby/node_modules/fdir": { - "version": "6.4.4", - "resolved": "https://registry.npmjs.org/fdir/-/fdir-6.4.4.tgz", - "integrity": "sha512-1NZP+GK4GfuAv3PqKvxQRDMjdSRZjnkq7KfhlNrCNNlZ0ygQFpebfrnfnq/W7fpUnAv9aGWmY1zKx7FYL3gwhg==", + "version": "6.4.5", + "resolved": "https://registry.npmjs.org/fdir/-/fdir-6.4.5.tgz", + "integrity": "sha512-4BG7puHpVsIYxZUbiUE3RqGloLaSSwzYie5jvasC4LWuBWzZawynvYouhjbQKw2JuIGYdm0DzIxl8iVidKlUEw==", "dev": true, "license": "MIT", "peerDependencies": { @@ -8579,9 +8789,9 @@ "license": "0BSD" }, "node_modules/tw-animate-css": { - 
"version": "1.2.8", - "resolved": "https://registry.npmjs.org/tw-animate-css/-/tw-animate-css-1.2.8.tgz", - "integrity": "sha512-AxSnYRvyFnAiZCUndS3zQZhNfV/B77ZhJ+O7d3K6wfg/jKJY+yv6ahuyXwnyaYA9UdLqnpCwhTRv9pPTBnPR2g==", + "version": "1.3.4", + "resolved": "https://registry.npmjs.org/tw-animate-css/-/tw-animate-css-1.3.4.tgz", + "integrity": "sha512-dd1Ht6/YQHcNbq0znIT6dG8uhO7Ce+VIIhZUhjsryXsMPJQz3bZg7Q2eNzLwipb25bRZslGb2myio5mScd1TFg==", "dev": true, "license": "MIT", "funding": { @@ -8742,9 +8952,9 @@ "license": "MIT" }, "node_modules/unrs-resolver": { - "version": "1.7.2", - "resolved": "https://registry.npmjs.org/unrs-resolver/-/unrs-resolver-1.7.2.tgz", - "integrity": "sha512-BBKpaylOW8KbHsu378Zky/dGh4ckT/4NW/0SHRABdqRLcQJ2dAOjDo9g97p04sWflm0kqPqpUatxReNV/dqI5A==", + "version": "1.7.10", + "resolved": "https://registry.npmjs.org/unrs-resolver/-/unrs-resolver-1.7.10.tgz", + "integrity": "sha512-CJEMJcz6vuwRK6xxWc+uf8AGi0OyfoVtHs5mExtNecS0HZq3a3Br1JC/InwwTn6uy+qkAdAdK+nJUYO9FPtgZw==", "dev": true, "hasInstallScript": true, "license": "MIT", @@ -8752,26 +8962,26 @@ "napi-postinstall": "^0.2.2" }, "funding": { - "url": "https://github.com/sponsors/JounQin" + "url": "https://opencollective.com/unrs-resolver" }, "optionalDependencies": { - "@unrs/resolver-binding-darwin-arm64": "1.7.2", - "@unrs/resolver-binding-darwin-x64": "1.7.2", - "@unrs/resolver-binding-freebsd-x64": "1.7.2", - "@unrs/resolver-binding-linux-arm-gnueabihf": "1.7.2", - "@unrs/resolver-binding-linux-arm-musleabihf": "1.7.2", - "@unrs/resolver-binding-linux-arm64-gnu": "1.7.2", - "@unrs/resolver-binding-linux-arm64-musl": "1.7.2", - "@unrs/resolver-binding-linux-ppc64-gnu": "1.7.2", - "@unrs/resolver-binding-linux-riscv64-gnu": "1.7.2", - "@unrs/resolver-binding-linux-riscv64-musl": "1.7.2", - "@unrs/resolver-binding-linux-s390x-gnu": "1.7.2", - "@unrs/resolver-binding-linux-x64-gnu": "1.7.2", - "@unrs/resolver-binding-linux-x64-musl": "1.7.2", - "@unrs/resolver-binding-wasm32-wasi": "1.7.2", - 
"@unrs/resolver-binding-win32-arm64-msvc": "1.7.2", - "@unrs/resolver-binding-win32-ia32-msvc": "1.7.2", - "@unrs/resolver-binding-win32-x64-msvc": "1.7.2" + "@unrs/resolver-binding-darwin-arm64": "1.7.10", + "@unrs/resolver-binding-darwin-x64": "1.7.10", + "@unrs/resolver-binding-freebsd-x64": "1.7.10", + "@unrs/resolver-binding-linux-arm-gnueabihf": "1.7.10", + "@unrs/resolver-binding-linux-arm-musleabihf": "1.7.10", + "@unrs/resolver-binding-linux-arm64-gnu": "1.7.10", + "@unrs/resolver-binding-linux-arm64-musl": "1.7.10", + "@unrs/resolver-binding-linux-ppc64-gnu": "1.7.10", + "@unrs/resolver-binding-linux-riscv64-gnu": "1.7.10", + "@unrs/resolver-binding-linux-riscv64-musl": "1.7.10", + "@unrs/resolver-binding-linux-s390x-gnu": "1.7.10", + "@unrs/resolver-binding-linux-x64-gnu": "1.7.10", + "@unrs/resolver-binding-linux-x64-musl": "1.7.10", + "@unrs/resolver-binding-wasm32-wasi": "1.7.10", + "@unrs/resolver-binding-win32-arm64-msvc": "1.7.10", + "@unrs/resolver-binding-win32-ia32-msvc": "1.7.10", + "@unrs/resolver-binding-win32-x64-msvc": "1.7.10" } }, "node_modules/uri-js": { @@ -9091,6 +9301,16 @@ "dev": true, "license": "ISC" }, + "node_modules/yallist": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/yallist/-/yallist-5.0.0.tgz", + "integrity": "sha512-YgvUTfwqyc7UXVMrB+SImsVYSmTS8X/tSrtdNZMImM+n7+QTriRXyXim0mBrTXNeqzVF0KWGgHPeiyViFFrNDw==", + "dev": true, + "license": "BlueOak-1.0.0", + "engines": { + "node": ">=18" + } + }, "node_modules/yaml": { "version": "2.3.1", "resolved": "https://registry.npmjs.org/yaml/-/yaml-2.3.1.tgz", @@ -9115,18 +9335,18 @@ } }, "node_modules/zod": { - "version": "3.24.3", - "resolved": "https://registry.npmjs.org/zod/-/zod-3.24.3.tgz", - "integrity": "sha512-HhY1oqzWCQWuUqvBFnsyrtZRhyPeR7SUGv+C4+MsisMuVfSPx8HpwWqH8tRahSlt6M3PiFAcoeFhZAqIXTxoSg==", + "version": "3.25.51", + "resolved": "https://registry.npmjs.org/zod/-/zod-3.25.51.tgz", + "integrity": 
"sha512-TQSnBldh+XSGL+opiSIq0575wvDPqu09AqWe1F7JhUMKY+M91/aGlK4MhpVNO7MgYfHcVCB1ffwAUTJzllKJqg==", "license": "MIT", "funding": { "url": "https://github.com/sponsors/colinhacks" } }, "node_modules/zustand": { - "version": "5.0.3", - "resolved": "https://registry.npmjs.org/zustand/-/zustand-5.0.3.tgz", - "integrity": "sha512-14fwWQtU3pH4dE0dOpdMiWjddcH+QzKIgk1cl8epwSE7yag43k/AD/m4L6+K7DytAOr9gGBe3/EXj9g7cdostg==", + "version": "5.0.5", + "resolved": "https://registry.npmjs.org/zustand/-/zustand-5.0.5.tgz", + "integrity": "sha512-mILtRfKW9xM47hqxGIxCv12gXusoY/xTSHBYApXozR0HmQv299whhBeeAcRy+KrPPybzosvJBCOmVjq6x12fCg==", "license": "MIT", "engines": { "node": ">=12.20.0" diff --git a/frontend/tsconfig.json b/frontend/tsconfig.json index 7b28589..f48e7ee 100644 --- a/frontend/tsconfig.json +++ b/frontend/tsconfig.json @@ -1,6 +1,10 @@ { "compilerOptions": { - "lib": ["dom", "dom.iterable", "esnext"], + "lib": [ + "dom", + "dom.iterable", + "esnext" + ], "allowJs": true, "skipLibCheck": true, "strict": true, @@ -18,9 +22,19 @@ } ], "paths": { - "@/*": ["./src/*"] - } + "@/*": [ + "./src/*" + ] + }, + "target": "ES2017" }, - "include": ["next-env.d.ts", "**/*.ts", "**/*.tsx", ".next/types/**/*.ts"], - "exclude": ["node_modules"] + "include": [ + "next-env.d.ts", + "**/*.ts", + "**/*.tsx", + ".next/types/**/*.ts" + ], + "exclude": [ + "node_modules" + ] } From 2f63c141cd800c349caf490a60344fea4e5acda2 Mon Sep 17 00:00:00 2001 From: poornimaramesh Date: Fri, 6 Jun 2025 12:13:12 +0300 Subject: [PATCH 55/74] experiment tests for mabs --- backend/tests/test_experiments.py | 399 ++++++++++++++++++++++++++++++ 1 file changed, 399 insertions(+) create mode 100644 backend/tests/test_experiments.py diff --git a/backend/tests/test_experiments.py b/backend/tests/test_experiments.py new file mode 100644 index 0000000..0173d70 --- /dev/null +++ b/backend/tests/test_experiments.py @@ -0,0 +1,399 @@ +import copy +import os +from typing import Generator + +from fastapi.testclient import 
TestClient +from pytest import FixtureRequest, fixture, mark +from sqlalchemy.orm import Session + +from backend.app.experiments.models import ArmDB, ExperimentDB, NotificationsDB + +mab_beta_binom_payload = { + "name": "Test", + "description": "Test description.", + "exp_type": "mab", + "prior_type": "beta", + "reward_type": "binary", + "arms": [ + { + "name": "arm 1", + "description": "arm 1 description.", + "alpha_init": 5, + "beta_init": 1, + }, + { + "name": "arm 2", + "description": "arm 2 description.", + "alpha_init": 1, + "beta_init": 4, + }, + ], + "notifications": { + "onTrialCompletion": True, + "numberOfTrials": 2, + "onDaysElapsed": False, + "daysElapsed": 3, + "onPercentBetter": False, + "percentBetterThreshold": 5, + }, + "contexts": [], + "clients": [], +} + + +@fixture +def admin_token(client: TestClient) -> str: + response = client.post( + "/login", + data={ + "username": os.environ.get("ADMIN_USERNAME", ""), + "password": os.environ.get("ADMIN_PASSWORD", ""), + }, + ) + token = response.json()["access_token"] + return token + + +@fixture +def clean_experiments(db_session: Session) -> Generator: + yield + db_session.query(NotificationsDB).delete() + db_session.query(ArmDB).delete() + db_session.query(ExperimentDB).delete() + db_session.commit() + + +class TestExperiment: + @fixture + def create_experiment_payload(self, request: FixtureRequest) -> dict: + payload_mab_beta_binom: dict = copy.deepcopy(mab_beta_binom_payload) + payload_mab_beta_binom["arms"] = list(payload_mab_beta_binom["arms"]) + + payload_mab_normal: dict = copy.deepcopy(mab_beta_binom_payload) + payload_mab_normal["prior_type"] = "normal" + payload_mab_normal["reward_type"] = "real-valued" + payload_mab_normal["arms"] = [ + { + "name": "arm 1", + "description": "arm 1 description", + "mu_init": 2, + "sigma_init": 3, + }, + { + "name": "arm 2", + "description": "arm 2 description", + "mu_init": 3, + "sigma_init": 7, + }, + ] + + if request.param == "base_beta_binom": + return 
payload_mab_beta_binom + if request.param == "base_normal": + return payload_mab_normal + if request.param == "one_arm": + payload_mab_beta_binom["arms"].pop() + return payload_mab_beta_binom + if request.param == "no_notifications": + payload_mab_beta_binom["notifications"]["onTrialCompletion"] = False + return payload_mab_beta_binom + if request.param == "invalid_prior": + payload_mab_beta_binom["prior_type"] = "invalid" + return payload_mab_beta_binom + if request.param == "invalid_reward": + payload_mab_beta_binom["reward_type"] = "invalid" + return payload_mab_beta_binom + if request.param == "invalid_alpha": + payload_mab_beta_binom["arms"][0]["alpha_init"] = -1 + return payload_mab_beta_binom + if request.param == "invalid_beta": + payload_mab_beta_binom["arms"][0]["beta_init"] = -1 + return payload_mab_beta_binom + if request.param == "invalid_combo": + payload_mab_beta_binom["reward_type"] = "real-valued" + return payload_mab_beta_binom + if request.param == "incorrect_params": + payload_mab_beta_binom["arms"][0].pop("alpha_init") + return payload_mab_beta_binom + if request.param == "invalid_sigma": + payload_mab_normal["arms"][0]["sigma_init"] = 0.0 + return payload_mab_normal + if request.param == "invalid_context_input": + payload_mab_beta_binom["contexts"] = [ + { + "name": "context 1", + "description": "context 1 description", + "value_type": "binary", + } + ] + return payload_mab_beta_binom + else: + raise ValueError("Invalid parameter") + + @mark.parametrize( + "create_experiment_payload, expected_response", + [ + ("base_beta_binom", 200), + ("base_normal", 200), + ("one_arm", 422), + ("no_notifications", 200), + ("invalid_prior", 422), + ("invalid_reward", 422), + ("invalid_alpha", 422), + ("invalid_beta", 422), + ("invalid_sigma", 422), + ("invalid_combo", 422), + ("incorrect_params", 422), + ("invalid_context_input", 422), + ], + indirect=["create_experiment_payload"], + ) + def test_create_experiment( + self, + create_experiment_payload: dict, 
+ client: TestClient, + expected_response: int, + admin_token: str, + clean_experiments: None, + ) -> None: + response = client.post( + "/experiment", + json=create_experiment_payload, + headers={"Authorization": f"Bearer {admin_token}"}, + ) + + assert response.status_code == expected_response + + @fixture + def create_experiments( + self, + client: TestClient, + admin_token: str, + request: FixtureRequest, + create_experiment_payload: dict, + ) -> Generator: + experiments = [] + n_experiments = request.param if hasattr(request, "param") else 1 + for _ in range(n_experiments): + response = client.post( + "/experiment", + json=create_experiment_payload, + headers={"Authorization": f"Bearer {admin_token}"}, + ) + experiments.append(response.json()) + yield experiments + for experiment in experiments: + client.delete( + f"/experiment/id/{experiment['experiment_id']}", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + + @mark.parametrize( + "create_experiments, create_experiment_payload, n_expected", + [ + (0, "base_beta_binom", 0), + (2, "base_beta_binom", 2), + (5, "base_beta_binom", 5), + ], + indirect=["create_experiments", "create_experiment_payload"], + ) + def test_get_all_experiments( + self, + client: TestClient, + admin_token: str, + n_expected: int, + create_experiments: list, + create_experiment_payload: dict, + ) -> None: + response = client.get( + "/experiment", headers={"Authorization": f"Bearer {admin_token}"} + ) + assert response.status_code == 200 + assert len(response.json()) == n_expected + + @mark.parametrize( + "create_experiments, create_experiment_payload, expected_response", + [(0, "base_beta_binom", 404), (2, "base_beta_binom", 200)], + indirect=["create_experiments", "create_experiment_payload"], + ) + def test_get_experiment( + self, + client: TestClient, + admin_token: str, + create_experiments: list, + create_experiment_payload: dict, + expected_response: int, + ) -> None: + id = create_experiments[0]["experiment_id"] if 
create_experiments else 999 + + response = client.get( + f"/experiment/id/{id}/", headers={"Authorization": f"Bearer {admin_token}"} + ) + assert response.status_code == expected_response + + @mark.parametrize("create_experiment_payload", ["base_beta_binom"], indirect=True) + def test_draw_arm_draw_id_provided( + self, + client: TestClient, + create_experiments: list, + create_experiment_payload: dict, + workspace_api_key: str, + ) -> None: + id = create_experiments[0]["experiment_id"] + response = client.put( + f"/experiment/{id}/draw", + params={"draw_id": "test_draw"}, + headers={"Authorization": f"Bearer {workspace_api_key}"}, + ) + assert response.status_code == 200 + assert response.json()["draw_id"] == "test_draw" + + @mark.parametrize("create_experiment_payload", ["base_beta_binom"], indirect=True) + def test_draw_arm_no_draw_id_provided( + self, + client: TestClient, + create_experiments: list, + create_experiment_payload: dict, + workspace_api_key: str, + ) -> None: + id = create_experiments[0]["experiment_id"] + response = client.put( + f"/experiment/{id}/draw", + headers={"Authorization": f"Bearer {workspace_api_key}"}, + ) + assert response.status_code == 200 + assert len(response.json()["draw_id"]) == 36 + + @mark.parametrize("create_experiment_payload", ["base_beta_binom"], indirect=True) + def test_one_outcome_per_draw( + self, + client: TestClient, + create_experiments: list, + create_experiment_payload: dict, + workspace_api_key: str, + ) -> None: + id = create_experiments[0]["experiment_id"] + response = client.put( + f"/experiment/{id}/draw", + headers={"Authorization": f"Bearer {workspace_api_key}"}, + ) + assert response.status_code == 200 + draw_id = response.json()["draw_id"] + + response = client.put( + f"/experiment/{id}/{draw_id}/1", + headers={"Authorization": f"Bearer {workspace_api_key}"}, + ) + + assert response.status_code == 200 + + response = client.put( + f"/experiment/{id}/{draw_id}/1", + headers={"Authorization": f"Bearer 
{workspace_api_key}"}, + ) + + assert response.status_code == 400 + + @mark.parametrize( + "n_draws, create_experiment_payload", + [(0, "base_beta_binom"), (1, "base_beta_binom"), (5, "base_beta_binom")], + indirect=["create_experiment_payload"], + ) + def test_get_rewards( + self, + client: TestClient, + create_experiments: list, + n_draws: int, + create_experiment_payload: dict, + workspace_api_key: str, + ) -> None: + id = create_experiments[0]["experiment_id"] + + for _ in range(n_draws): + response = client.put( + f"/experiment/{id}/draw", + headers={"Authorization": f"Bearer {workspace_api_key}"}, + ) + assert response.status_code == 200 + draw_id = response.json()["draw_id"] + # put outcomes + response = client.put( + f"/experiment/{id}/{draw_id}/1", + headers={"Authorization": f"Bearer {workspace_api_key}"}, + ) + + response = client.get( + f"/experiment/{id}/rewards", + headers={"Authorization": f"Bearer {workspace_api_key}"}, + ) + + assert response.status_code == 200 + assert len(response.json()) == n_draws + + +class TestNotifications: + @fixture() + def create_experiment_payload(self, request: FixtureRequest) -> dict: + payload: dict = copy.deepcopy(mab_beta_binom_payload) + payload["arms"] = list(payload["arms"]) + + match request.param: + case "base": + pass + case "daysElapsed_only": + payload["notifications"]["onTrialCompletion"] = False + payload["notifications"]["onDaysElapsed"] = True + case "trialCompletion_only": + payload["notifications"]["onTrialCompletion"] = True + case "percentBetter_only": + payload["notifications"]["onTrialCompletion"] = False + payload["notifications"]["onPercentBetter"] = True + case "all_notifications": + payload["notifications"]["onDaysElapsed"] = True + payload["notifications"]["onPercentBetter"] = True + case "no_notifications": + payload["notifications"]["onTrialCompletion"] = False + case "daysElapsed_missing": + payload["notifications"]["daysElapsed"] = 0 + payload["notifications"]["onDaysElapsed"] = True + 
case "trialCompletion_missing": + payload["notifications"]["numberOfTrials"] = 0 + payload["notifications"]["onTrialCompletion"] = True + case "percentBetter_missing": + payload["notifications"]["percentBetterThreshold"] = 0 + payload["notifications"]["onPercentBetter"] = True + case _: + raise ValueError("Invalid parameter") + + return payload + + @mark.parametrize( + "create_experiment_payload, expected_response", + [ + ("base", 200), + ("daysElapsed_only", 200), + ("trialCompletion_only", 200), + ("percentBetter_only", 200), + ("all_notifications", 200), + ("no_notifications", 200), + ("daysElapsed_missing", 422), + ("trialCompletion_missing", 422), + ("percentBetter_missing", 422), + ], + indirect=["create_experiment_payload"], + ) + def test_notifications( + self, + client: TestClient, + admin_token: str, + create_experiment_payload: dict, + expected_response: int, + clean_experiments: None, + ) -> None: + response = client.post( + "/experiment", + json=create_experiment_payload, + headers={"Authorization": f"Bearer {admin_token}"}, + ) + + assert response.status_code == expected_response From 09aecdccabf77ea9c422e0d3af1c9c7c9ce95adb Mon Sep 17 00:00:00 2001 From: poornimaramesh Date: Fri, 6 Jun 2025 15:55:18 +0300 Subject: [PATCH 56/74] update tests for Bayes AB experiments --- backend/tests/test_experiments.py | 180 +++++++++++++++++++++++------- 1 file changed, 140 insertions(+), 40 deletions(-) diff --git a/backend/tests/test_experiments.py b/backend/tests/test_experiments.py index 0173d70..3a268ba 100644 --- a/backend/tests/test_experiments.py +++ b/backend/tests/test_experiments.py @@ -20,12 +20,14 @@ "description": "arm 1 description.", "alpha_init": 5, "beta_init": 1, + "is_treatment_arm": True, }, { "name": "arm 2", "description": "arm 2 description.", "alpha_init": 1, "beta_init": 4, + "is_treatment_arm": False, }, ], "notifications": { @@ -63,62 +65,64 @@ def clean_experiments(db_session: Session) -> Generator: db_session.commit() -class 
TestExperiment: - @fixture - def create_experiment_payload(self, request: FixtureRequest) -> dict: - payload_mab_beta_binom: dict = copy.deepcopy(mab_beta_binom_payload) - payload_mab_beta_binom["arms"] = list(payload_mab_beta_binom["arms"]) - - payload_mab_normal: dict = copy.deepcopy(mab_beta_binom_payload) - payload_mab_normal["prior_type"] = "normal" - payload_mab_normal["reward_type"] = "real-valued" - payload_mab_normal["arms"] = [ - { - "name": "arm 1", - "description": "arm 1 description", - "mu_init": 2, - "sigma_init": 3, - }, - { - "name": "arm 2", - "description": "arm 2 description", - "mu_init": 3, - "sigma_init": 7, - }, - ] - - if request.param == "base_beta_binom": +def _get_experiment_payload(input: str) -> dict: + """Helper function to get the experiment payload based on input.""" + payload_mab_beta_binom: dict = copy.deepcopy(mab_beta_binom_payload) + payload_mab_beta_binom["arms"] = list(payload_mab_beta_binom["arms"]) + + payload_mab_normal: dict = copy.deepcopy(mab_beta_binom_payload) + payload_mab_normal["prior_type"] = "normal" + payload_mab_normal["reward_type"] = "real-valued" + payload_mab_normal["arms"] = [ + { + "name": "arm 1", + "description": "arm 1 description", + "mu_init": 2, + "sigma_init": 3, + "is_treatment_arm": True, + }, + { + "name": "arm 2", + "description": "arm 2 description", + "mu_init": 3, + "sigma_init": 7, + "is_treatment_arm": True, + }, + ] + + match input: + case "base_beta_binom": return payload_mab_beta_binom - if request.param == "base_normal": + case "base_normal": return payload_mab_normal - if request.param == "one_arm": + case "one_arm": payload_mab_beta_binom["arms"].pop() return payload_mab_beta_binom - if request.param == "no_notifications": + case "no_notifications": payload_mab_beta_binom["notifications"]["onTrialCompletion"] = False return payload_mab_beta_binom - if request.param == "invalid_prior": + case "invalid_prior": payload_mab_beta_binom["prior_type"] = "invalid" return 
payload_mab_beta_binom - if request.param == "invalid_reward": + case "invalid_reward": payload_mab_beta_binom["reward_type"] = "invalid" return payload_mab_beta_binom - if request.param == "invalid_alpha": + case "invalid_alpha": payload_mab_beta_binom["arms"][0]["alpha_init"] = -1 return payload_mab_beta_binom - if request.param == "invalid_beta": + case "invalid_beta": payload_mab_beta_binom["arms"][0]["beta_init"] = -1 return payload_mab_beta_binom - if request.param == "invalid_combo": + case "invalid_combo": payload_mab_beta_binom["reward_type"] = "real-valued" return payload_mab_beta_binom - if request.param == "incorrect_params": + case "incorrect_params": payload_mab_beta_binom["arms"][0].pop("alpha_init") return payload_mab_beta_binom - if request.param == "invalid_sigma": + case "invalid_sigma": payload_mab_normal["arms"][0]["sigma_init"] = 0.0 return payload_mab_normal - if request.param == "invalid_context_input": + case "invalid_context_input": payload_mab_beta_binom["contexts"] = [ { "name": "context 1", @@ -127,8 +131,42 @@ def create_experiment_payload(self, request: FixtureRequest) -> dict: } ] return payload_mab_beta_binom - else: - raise ValueError("Invalid parameter") + case "bayes_ab_normal_binom": + payload_mab_normal["exp_type"] = "bayes_ab" + payload_mab_normal["reward_type"] = "real-valued" + payload_mab_normal["arms"][1]["is_treatment_arm"] = False + return payload_mab_normal + case "bayes_ab_invalid_prior": + payload_mab_beta_binom["exp_type"] = "bayes_ab" + payload_mab_beta_binom["arms"][1]["is_treatment_arm"] = False + return payload_mab_beta_binom + case "bayes_ab_invalid_arm": + payload_mab_normal["exp_type"] = "bayes_ab" + payload_mab_normal["reward_type"] = "real-valued" + return payload_mab_normal + case "bayes_ab_invalid_context": + payload_mab_normal["exp_type"] = "bayes_ab" + payload_mab_normal["reward_type"] = "real-valued" + payload_mab_normal["arms"][1]["is_treatment_arm"] = False + payload_mab_normal["contexts"] = [ + { + 
"name": "context 1", + "description": "context 1 description", + "value_type": "binary", + } + ] + return payload_mab_normal + case _: + raise ValueError(f"Invalid input: {input}.") + + +class TestExperiment: + @fixture + def create_experiment_payload(self, request: FixtureRequest) -> dict: + """Fixture to create experiment payload based on request parameter.""" + return ( + _get_experiment_payload(request.param) if hasattr(request, "param") else {} + ) @mark.parametrize( "create_experiment_payload, expected_response", @@ -145,6 +183,10 @@ def create_experiment_payload(self, request: FixtureRequest) -> dict: ("invalid_combo", 422), ("incorrect_params", 422), ("invalid_context_input", 422), + ("bayes_ab_normal_binom", 200), + ("bayes_ab_invalid_prior", 422), + ("bayes_ab_invalid_arm", 422), + ("bayes_ab_invalid_context", 422), ], indirect=["create_experiment_payload"], ) @@ -188,6 +230,29 @@ def create_experiments( headers={"Authorization": f"Bearer {admin_token}"}, ) + @fixture + def create_mixed_experiments( + self, + client: TestClient, + admin_token: str, + request: FixtureRequest, + ) -> Generator: + mixed_payload = [] + for param in request.param: + payload = _get_experiment_payload(param) + response = client.post( + "/experiment", + json=payload, + headers={"Authorization": f"Bearer {admin_token}"}, + ) + mixed_payload.append(response.json()) + yield mixed_payload + for experiment in mixed_payload: + client.delete( + f"/experiment/id/{experiment['experiment_id']}", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + @mark.parametrize( "create_experiments, create_experiment_payload, n_expected", [ @@ -211,6 +276,30 @@ def test_get_all_experiments( assert response.status_code == 200 assert len(response.json()) == n_expected + @mark.parametrize( + "create_mixed_experiments, exp_type, n_expected", + [ + (["base_beta_binom", "bayes_ab_normal_binom"], "mab", 1), + (["base_beta_binom", "bayes_ab_normal_binom"], "bayes_ab", 1), + (["base_beta_binom", 
"bayes_ab_normal_binom"], "cmab", 0), + ], + indirect=["create_mixed_experiments"], + ) + def test_get_all_experiments_by_type( + self, + client: TestClient, + admin_token: str, + n_expected: int, + create_mixed_experiments: list, + exp_type: str, + ) -> None: + response = client.get( + f"/experiment/type/{exp_type}", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert response.status_code == 200 + assert len(response.json()) == n_expected + @mark.parametrize( "create_experiments, create_experiment_payload, expected_response", [(0, "base_beta_binom", 404), (2, "base_beta_binom", 200)], @@ -264,7 +353,11 @@ def test_draw_arm_no_draw_id_provided( assert response.status_code == 200 assert len(response.json()["draw_id"]) == 36 - @mark.parametrize("create_experiment_payload", ["base_beta_binom"], indirect=True) + @mark.parametrize( + "create_experiment_payload", + ["base_beta_binom", "bayes_ab_normal_binom"], + indirect=True, + ) def test_one_outcome_per_draw( self, client: TestClient, @@ -296,7 +389,14 @@ def test_one_outcome_per_draw( @mark.parametrize( "n_draws, create_experiment_payload", - [(0, "base_beta_binom"), (1, "base_beta_binom"), (5, "base_beta_binom")], + [ + (0, "base_beta_binom"), + (1, "base_beta_binom"), + (5, "base_beta_binom"), + (0, "bayes_ab_normal_binom"), + (1, "bayes_ab_normal_binom"), + (5, "bayes_ab_normal_binom"), + ], indirect=["create_experiment_payload"], ) def test_get_rewards( From 87013243ec44e91cfaf0014868927449b03f9514 Mon Sep 17 00:00:00 2001 From: poornimaramesh Date: Fri, 6 Jun 2025 16:53:33 +0300 Subject: [PATCH 57/74] add tests for CMAB --- backend/tests/test_experiments.py | 121 ++++++++++++++++++++++++++++-- 1 file changed, 113 insertions(+), 8 deletions(-) diff --git a/backend/tests/test_experiments.py b/backend/tests/test_experiments.py index 3a268ba..f2bdd7c 100644 --- a/backend/tests/test_experiments.py +++ b/backend/tests/test_experiments.py @@ -6,7 +6,12 @@ from pytest import FixtureRequest, fixture, mark 
from sqlalchemy.orm import Session -from backend.app.experiments.models import ArmDB, ExperimentDB, NotificationsDB +from backend.app.experiments.models import ( + ArmDB, + ContextDB, + ExperimentDB, + NotificationsDB, +) mab_beta_binom_payload = { "name": "Test", @@ -60,6 +65,7 @@ def admin_token(client: TestClient) -> str: def clean_experiments(db_session: Session) -> Generator: yield db_session.query(NotificationsDB).delete() + db_session.query(ContextDB).delete() db_session.query(ArmDB).delete() db_session.query(ExperimentDB).delete() db_session.commit() @@ -156,6 +162,57 @@ def _get_experiment_payload(input: str) -> dict: } ] return payload_mab_normal + case "cmab_normal": + payload_mab_normal["exp_type"] = "cmab" + payload_mab_normal["contexts"] = [ + { + "name": "context 1", + "description": "context 1 description", + "value_type": "binary", + }, + { + "name": "context 2", + "description": "context 2 description", + "value_type": "real-valued", + }, + ] + return payload_mab_normal + case "cmab_normal_binomial": + payload_mab_normal["exp_type"] = "cmab" + payload_mab_normal["reward_type"] = "binary" + payload_mab_normal["contexts"] = [ + { + "name": "context 1", + "description": "context 1 description", + "value_type": "binary", + }, + { + "name": "context 2", + "description": "context 2 description", + "value_type": "real-valued", + }, + ] + return payload_mab_normal + case "cmab_invalid_prior": + payload_mab_normal["exp_type"] = "cmab" + payload_mab_normal["prior_type"] = "beta" + payload_mab_normal["contexts"] = [ + { + "name": "context 1", + "description": "context 1 description", + "value_type": "binary", + }, + { + "name": "context 2", + "description": "context 2 description", + "value_type": "real-valued", + }, + ] + return payload_mab_normal + case "cmab_invalid_context": + payload_mab_normal["exp_type"] = "cmab" + return payload_mab_normal + case _: raise ValueError(f"Invalid input: {input}.") @@ -164,9 +221,7 @@ class TestExperiment: @fixture def 
create_experiment_payload(self, request: FixtureRequest) -> dict: """Fixture to create experiment payload based on request parameter.""" - return ( - _get_experiment_payload(request.param) if hasattr(request, "param") else {} - ) + return _get_experiment_payload(request.param) @mark.parametrize( "create_experiment_payload, expected_response", @@ -187,6 +242,10 @@ def create_experiment_payload(self, request: FixtureRequest) -> dict: ("bayes_ab_invalid_prior", 422), ("bayes_ab_invalid_arm", 422), ("bayes_ab_invalid_context", 422), + ("cmab_normal", 200), + ("cmab_normal_binomial", 200), + ("cmab_invalid_prior", 422), + ("cmab_invalid_context", 422), ], indirect=["create_experiment_payload"], ) @@ -279,9 +338,36 @@ def test_get_all_experiments( @mark.parametrize( "create_mixed_experiments, exp_type, n_expected", [ - (["base_beta_binom", "bayes_ab_normal_binom"], "mab", 1), - (["base_beta_binom", "bayes_ab_normal_binom"], "bayes_ab", 1), - (["base_beta_binom", "bayes_ab_normal_binom"], "cmab", 0), + ( + [ + "base_beta_binom", + "base_normal", + "bayes_ab_normal_binom", + "cmab_normal", + ], + "mab", + 2, + ), + ( + [ + "base_beta_binom", + "bayes_ab_normal_binom", + "bayes_ab_normal_binom", + "cmab_normal", + ], + "bayes_ab", + 2, + ), + ( + [ + "base_beta_binom", + "bayes_ab_normal_binom", + "cmab_normal", + "cmab_normal_binomial", + ], + "cmab", + 2, + ), ], indirect=["create_mixed_experiments"], ) @@ -355,7 +441,7 @@ def test_draw_arm_no_draw_id_provided( @mark.parametrize( "create_experiment_payload", - ["base_beta_binom", "bayes_ab_normal_binom"], + ["base_beta_binom", "bayes_ab_normal_binom", "cmab_normal"], indirect=True, ) def test_one_outcome_per_draw( @@ -366,9 +452,17 @@ def test_one_outcome_per_draw( workspace_api_key: str, ) -> None: id = create_experiments[0]["experiment_id"] + exp_type = create_experiments[0]["exp_type"] + contexts = None + if exp_type == "cmab": + contexts = [ + {"context_id": context["context_id"], "context_value": 1} + for context in 
create_experiments[0]["contexts"] + ] response = client.put( f"/experiment/{id}/draw", headers={"Authorization": f"Bearer {workspace_api_key}"}, + json=contexts, ) assert response.status_code == 200 draw_id = response.json()["draw_id"] @@ -396,6 +490,9 @@ def test_one_outcome_per_draw( (0, "bayes_ab_normal_binom"), (1, "bayes_ab_normal_binom"), (5, "bayes_ab_normal_binom"), + (0, "cmab_normal"), + (1, "cmab_normal"), + (5, "cmab_normal"), ], indirect=["create_experiment_payload"], ) @@ -408,11 +505,19 @@ def test_get_rewards( workspace_api_key: str, ) -> None: id = create_experiments[0]["experiment_id"] + exp_type = create_experiments[0]["exp_type"] + contexts = None + if exp_type == "cmab": + contexts = [ + {"context_id": context["context_id"], "context_value": 1} + for context in create_experiments[0]["contexts"] + ] for _ in range(n_draws): response = client.put( f"/experiment/{id}/draw", headers={"Authorization": f"Bearer {workspace_api_key}"}, + json=contexts, ) assert response.status_code == 200 draw_id = response.json()["draw_id"] From 4a48364ae2eff4342e2ac191e443acd3f9ba8065 Mon Sep 17 00:00:00 2001 From: poornimaramesh Date: Fri, 6 Jun 2025 16:54:46 +0300 Subject: [PATCH 58/74] delete old tests --- backend/tests/test_bayes_ab.py | 430 ------------------------------ backend/tests/test_cmabs.py | 452 -------------------------------- backend/tests/test_mabs.py | 463 --------------------------------- 3 files changed, 1345 deletions(-) delete mode 100644 backend/tests/test_bayes_ab.py delete mode 100644 backend/tests/test_cmabs.py delete mode 100644 backend/tests/test_mabs.py diff --git a/backend/tests/test_bayes_ab.py b/backend/tests/test_bayes_ab.py deleted file mode 100644 index 6b75b6f..0000000 --- a/backend/tests/test_bayes_ab.py +++ /dev/null @@ -1,430 +0,0 @@ -import copy -import os -from typing import Generator - -import numpy as np -from fastapi.testclient import TestClient -from pytest import FixtureRequest, fixture, mark -from sqlalchemy.orm import 
Session - -from backend.app.bayes_ab.models import BayesianABArmDB, BayesianABDB -from backend.app.models import NotificationsDB - -base_normal_payload = { - "name": "Test", - "description": "Test description", - "prior_type": "normal", - "reward_type": "real-valued", - "sticky_assignment": False, - "arms": [ - { - "name": "arm 1", - "description": "arm 1 description", - "mu_init": 0, - "sigma_init": 1, - "is_treatment_arm": True, - }, - { - "name": "arm 2", - "description": "arm 2 description", - "mu_init": 2, - "sigma_init": 2, - "is_treatment_arm": False, - }, - ], - "notifications": { - "onTrialCompletion": True, - "numberOfTrials": 2, - "onDaysElapsed": False, - "daysElapsed": 3, - "onPercentBetter": False, - "percentBetterThreshold": 5, - }, -} - -base_binary_normal_payload = base_normal_payload.copy() -base_binary_normal_payload["reward_type"] = "binary" - - -@fixture -def clean_bayes_ab(db_session: Session) -> Generator: - """ - Fixture to clean the database before each test. - """ - yield - db_session.query(NotificationsDB).delete() - db_session.query(BayesianABArmDB).delete() - db_session.query(BayesianABDB).delete() - - db_session.commit() - - -@fixture -def admin_token(client: TestClient) -> str: - """Get a token for the admin user""" - response = client.post( - "/login", - data={ - "username": os.environ.get("ADMIN_USERNAME", "admin@idinsight.org"), - "password": os.environ.get("ADMIN_PASSWORD", "12345"), - }, - ) - assert response.status_code == 200, f"Login failed: {response.json()}" - token = response.json()["access_token"] - return token - - -class TestBayesAB: - """ - Test class for Bayesian A/B testing. - """ - - @fixture - def create_bayes_ab_payload(self, request: FixtureRequest) -> dict: - """ - Fixture to create a payload for the Bayesian A/B test. 
- """ - payload_normal: dict = copy.deepcopy(base_normal_payload) - payload_normal["arms"] = list(payload_normal["arms"]) - - payload_binary_normal: dict = copy.deepcopy(base_binary_normal_payload) - payload_binary_normal["arms"] = list(payload_binary_normal["arms"]) - - if request.param == "base_normal": - return payload_normal - if request.param == "base_binary_normal": - return payload_binary_normal - if request.param == "one_arm": - payload_normal["arms"].pop() - return payload_normal - if request.param == "no_notifications": - payload_normal["notifications"]["onTrialCompletion"] = False - return payload_normal - if request.param == "invalid_prior": - payload_normal["prior_type"] = "beta" - return payload_normal - if request.param == "invalid_sigma": - payload_normal["arms"][0]["sigma_init"] = 0 - return payload_normal - if request.param == "invalid_params": - payload_normal["arms"][0].pop("mu_init") - return payload_normal - if request.param == "two_treatment_arms": - payload_normal["arms"][0]["is_treatment_arm"] = True - payload_normal["arms"][1]["is_treatment_arm"] = True - return payload_normal - if request.param == "with_sticky_assignment": - payload_normal["sticky_assignment"] = True - return payload_normal - else: - raise ValueError("Invalid parameter") - - @fixture - def create_bayes_abs( - self, - client: TestClient, - admin_token: str, - create_bayes_ab_payload: dict, - request: FixtureRequest, - ) -> Generator: - bayes_abs = [] - n_bayes_abs = request.param if hasattr(request, "param") else 1 - for _ in range(n_bayes_abs): - response = client.post( - "/bayes_ab", - json=create_bayes_ab_payload, - headers={"Authorization": f"Bearer {admin_token}"}, - ) - bayes_abs.append(response.json()) - yield bayes_abs - for bayes_ab in bayes_abs: - client.delete( - f"/bayes_ab/{bayes_ab['experiment_id']}", - headers={"Authorization": f"Bearer {admin_token}"}, - ) - - @mark.parametrize( - "create_bayes_ab_payload, expected_response", - [ - ("base_normal", 200), - 
("base_binary_normal", 200), - ("one_arm", 422), - ("no_notifications", 200), - ("invalid_prior", 422), - ("invalid_sigma", 422), - ("invalid_params", 200), - ("two_treatment_arms", 422), - ], - indirect=["create_bayes_ab_payload"], - ) - def test_create_bayes_ab( - self, - create_bayes_ab_payload: dict, - client: TestClient, - expected_response: int, - admin_token: str, - clean_bayes_ab: None, - ) -> None: - """ - Test the creation of a Bayesian A/B test. - """ - response = client.post( - "/bayes_ab", - json=create_bayes_ab_payload, - headers={"Authorization": f"Bearer {admin_token}"}, - ) - - assert response.status_code == expected_response - - @mark.parametrize( - "create_bayes_abs, n_expected, create_bayes_ab_payload", - [(1, 1, "base_normal"), (2, 2, "base_normal"), (5, 5, "base_normal")], - indirect=["create_bayes_abs", "create_bayes_ab_payload"], - ) - def test_get_bayes_abs( - self, - client: TestClient, - n_expected: int, - admin_token: str, - create_bayes_abs: list, - create_bayes_ab_payload: dict, - ) -> None: - """ - Test the retrieval of Bayesian A/B tests. 
- """ - response = client.get( - "/bayes_ab", headers={"Authorization": f"Bearer {admin_token}"} - ) - - assert response.status_code == 200 - assert len(response.json()) == n_expected - - @mark.parametrize( - "create_bayes_abs, expected_response, create_bayes_ab_payload", - [(1, 200, "base_normal"), (2, 200, "base_normal"), (5, 200, "base_normal")], - indirect=["create_bayes_abs", "create_bayes_ab_payload"], - ) - def test_draw_arm( - self, - client: TestClient, - create_bayes_abs: list, - create_bayes_ab_payload: dict, - expected_response: int, - workspace_api_key: str, - ) -> None: - id = create_bayes_abs[0]["experiment_id"] - response = client.get( - f"/bayes_ab/{id}/draw", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - assert response.status_code == expected_response - - @mark.parametrize( - "create_bayes_ab_payload, client_id, expected_response", - [ - ("with_sticky_assignment", None, 400), - ("with_sticky_assignment", "test_client_id", 200), - ], - indirect=["create_bayes_ab_payload"], - ) - def test_draw_arm_with_client_id( - self, - client: TestClient, - create_bayes_abs: list, - create_bayes_ab_payload: dict, - client_id: str | None, - expected_response: int, - workspace_api_key: str, - ) -> None: - id = create_bayes_abs[0]["experiment_id"] - response = client.get( - f"/bayes_ab/{id}/draw{'?client_id=' + client_id if client_id else ''}", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - assert response.status_code == expected_response - - @mark.parametrize( - "create_bayes_ab_payload", ["with_sticky_assignment"], indirect=True - ) - def test_draw_arm_with_sticky_assignment( - self, - client: TestClient, - create_bayes_abs: list, - create_bayes_ab_payload: dict, - workspace_api_key: str, - ) -> None: - id = create_bayes_abs[0]["experiment_id"] - arm_ids = [] - for _ in range(10): - response = client.get( - f"/bayes_ab/{id}/draw?client_id=123", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - 
arm_ids.append(response.json()["arm"]["arm_id"]) - assert np.unique(arm_ids).size == 1 - - @mark.parametrize("create_bayes_ab_payload", ["base_normal"], indirect=True) - def test_update_observation( - self, - client: TestClient, - create_bayes_abs: list, - create_bayes_ab_payload: dict, - workspace_api_key: str, - ) -> None: - id = create_bayes_abs[0]["experiment_id"] - - # First, get a draw - response = client.get( - f"/bayes_ab/{id}/draw", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - assert response.status_code == 200 - draw_id = response.json()["draw_id"] - - # Then update with an observation - response = client.put( - f"/bayes_ab/{id}/{draw_id}/0.5", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - assert response.status_code == 200 - - # Test that we can't update the same draw twice - response = client.put( - f"/bayes_ab/{id}/{draw_id}/0.5", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - assert response.status_code == 400 - - @mark.parametrize("create_bayes_ab_payload", ["base_normal"], indirect=True) - def test_get_outcomes( - self, - client: TestClient, - create_bayes_abs: list, - create_bayes_ab_payload: dict, - workspace_api_key: str, - ) -> None: - id = create_bayes_abs[0]["experiment_id"] - - # First, get a draw - response = client.get( - f"/bayes_ab/{id}/draw", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - assert response.status_code == 200 - draw_id = response.json()["draw_id"] - - # Then update with an observation - response = client.put( - f"/bayes_ab/{id}/{draw_id}/0.5", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - assert response.status_code == 200 - - # Get outcomes - response = client.get( - f"/bayes_ab/{id}/outcomes", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - assert response.status_code == 200 - assert len(response.json()) == 1 - - @mark.parametrize("create_bayes_ab_payload", ["base_normal"], indirect=True) - def 
test_get_arms( - self, - client: TestClient, - create_bayes_abs: list, - create_bayes_ab_payload: dict, - workspace_api_key: str, - ) -> None: - id = create_bayes_abs[0]["experiment_id"] - - # First, get a draw - response = client.get( - f"/bayes_ab/{id}/draw", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - assert response.status_code == 200 - draw_id = response.json()["draw_id"] - - # Then update with an observation - response = client.put( - f"/bayes_ab/{id}/{draw_id}/0.5", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - assert response.status_code == 200 - - # Get arms - response = client.get( - f"/bayes_ab/{id}/arms", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - assert response.status_code == 200 - assert len(response.json()) == 2 - - -class TestNotifications: - @fixture() - def create_bayes_ab_payload(self, request: FixtureRequest) -> dict: - payload: dict = copy.deepcopy(base_normal_payload) - payload["arms"] = list(payload["arms"]) - - match request.param: - case "base": - pass - case "daysElapsed_only": - payload["notifications"]["onTrialCompletion"] = False - payload["notifications"]["onDaysElapsed"] = True - case "trialCompletion_only": - payload["notifications"]["onTrialCompletion"] = True - case "percentBetter_only": - payload["notifications"]["onTrialCompletion"] = False - payload["notifications"]["onPercentBetter"] = True - case "all_notifications": - payload["notifications"]["onDaysElapsed"] = True - payload["notifications"]["onPercentBetter"] = True - case "no_notifications": - payload["notifications"]["onTrialCompletion"] = False - case "daysElapsed_missing": - payload["notifications"]["daysElapsed"] = 0 - payload["notifications"]["onDaysElapsed"] = True - case "trialCompletion_missing": - payload["notifications"]["numberOfTrials"] = 0 - payload["notifications"]["onTrialCompletion"] = True - case "percentBetter_missing": - payload["notifications"]["percentBetterThreshold"] = 0 - 
payload["notifications"]["onPercentBetter"] = True - case _: - raise ValueError("Invalid parameter") - - return payload - - @mark.parametrize( - "create_bayes_ab_payload, expected_response", - [ - ("base", 200), - ("daysElapsed_only", 200), - ("trialCompletion_only", 200), - ("percentBetter_only", 200), - ("all_notifications", 200), - ("no_notifications", 200), - ("daysElapsed_missing", 422), - ("trialCompletion_missing", 422), - ("percentBetter_missing", 422), - ], - indirect=["create_bayes_ab_payload"], - ) - def test_notifications( - self, - client: TestClient, - admin_token: str, - create_bayes_ab_payload: dict, - expected_response: int, - clean_bayes_ab: None, - ) -> None: - response = client.post( - "/bayes_ab", - json=create_bayes_ab_payload, - headers={"Authorization": f"Bearer {admin_token}"}, - ) - - assert response.status_code == expected_response diff --git a/backend/tests/test_cmabs.py b/backend/tests/test_cmabs.py deleted file mode 100644 index d9b6ed0..0000000 --- a/backend/tests/test_cmabs.py +++ /dev/null @@ -1,452 +0,0 @@ -import copy -import os -from typing import Generator - -import numpy as np -from fastapi.testclient import TestClient -from pytest import FixtureRequest, fixture, mark -from sqlalchemy.orm import Session - -from backend.app.contextual_mab.models import ( - ContextDB, - ContextualArmDB, - ContextualBanditDB, -) -from backend.app.models import NotificationsDB - -base_normal_payload = { - "name": "Test", - "description": "Test description", - "prior_type": "normal", - "reward_type": "real-valued", - "sticky_assignment": False, - "arms": [ - { - "name": "arm 1", - "description": "arm 1 description", - "mu_init": 0, - "sigma_init": 1, - }, - { - "name": "arm 2", - "description": "arm 2 description", - "mu_init": 0, - "sigma_init": 1, - }, - ], - "contexts": [ - { - "name": "Context 1", - "description": "context 1 description", - "value_type": "binary", - }, - { - "name": "Context 2", - "description": "context 2 description", - 
"value_type": "real-valued", - }, - ], - "notifications": { - "onTrialCompletion": True, - "numberOfTrials": 2, - "onDaysElapsed": False, - "daysElapsed": 3, - "onPercentBetter": False, - "percentBetterThreshold": 5, - }, -} - -base_binary_normal_payload = base_normal_payload.copy() -base_binary_normal_payload["reward_type"] = "binary" - - -@fixture -def admin_token(client: TestClient) -> str: - response = client.post( - "/login", - data={ - "username": os.environ.get("ADMIN_USERNAME", ""), - "password": os.environ.get("ADMIN_PASSWORD", ""), - }, - ) - token = response.json()["access_token"] - return token - - -@fixture -def clean_cmabs(db_session: Session) -> Generator: - yield - db_session.query(NotificationsDB).delete() - db_session.query(ContextualArmDB).delete() - db_session.query(ContextDB).delete() - db_session.query(ContextualBanditDB).delete() - db_session.commit() - - -class TestCMab: - @fixture - def create_cmab_payload(self, request: FixtureRequest) -> dict: - payload_normal: dict = copy.deepcopy(base_normal_payload) - payload_normal["arms"] = list(payload_normal["arms"]) - payload_normal["contexts"] = list(payload_normal["contexts"]) - - payload_binary_normal: dict = copy.deepcopy(base_binary_normal_payload) - payload_binary_normal["arms"] = list(payload_binary_normal["arms"]) - payload_binary_normal["contexts"] = list(payload_binary_normal["contexts"]) - - if request.param == "base_normal": - return payload_normal - if request.param == "base_binary_normal": - return payload_binary_normal - if request.param == "one_arm": - payload_normal["arms"].pop() - return payload_normal - if request.param == "no_notifications": - payload_normal["notifications"]["onTrialCompletion"] = False - return payload_normal - if request.param == "invalid_prior": - payload_normal["prior_type"] = "beta" - return payload_normal - if request.param == "invalid_reward": - payload_normal["reward_type"] = "invalid" - return payload_normal - if request.param == "invalid_sigma": - 
payload_normal["arms"][0]["sigma_init"] = 0 - return payload_normal - if request.param == "with_sticky_assignment": - payload_normal["sticky_assignment"] = True - return payload_normal - - else: - raise ValueError("Invalid parameter") - - @mark.parametrize( - "create_cmab_payload, expected_response", - [ - ("base_normal", 200), - ("base_binary_normal", 200), - ("one_arm", 422), - ("no_notifications", 200), - ("invalid_prior", 422), - ("invalid_reward", 422), - ("invalid_sigma", 422), - ], - indirect=["create_cmab_payload"], - ) - def test_create_cmab( - self, - create_cmab_payload: dict, - client: TestClient, - expected_response: int, - admin_token: str, - clean_cmabs: None, - ) -> None: - response = client.post( - "/contextual_mab", - json=create_cmab_payload, - headers={"Authorization": f"Bearer {admin_token}"}, - ) - - assert response.status_code == expected_response - - @fixture - def create_cmabs( - self, - client: TestClient, - admin_token: str, - request: FixtureRequest, - create_cmab_payload: dict, - ) -> Generator: - cmabs = [] - n_cmabs = request.param if hasattr(request, "param") else 1 - for _ in range(n_cmabs): - response = client.post( - "/contextual_mab", - json=create_cmab_payload, - headers={"Authorization": f"Bearer {admin_token}"}, - ) - cmabs.append(response.json()) - yield cmabs - for cmab in cmabs: - client.delete( - f"/contextual_mab/{cmab['experiment_id']}", - headers={"Authorization": f"Bearer {admin_token}"}, - ) - - @mark.parametrize( - "create_cmabs, n_expected, create_cmab_payload", - [(0, 0, "base_normal"), (2, 2, "base_normal"), (5, 5, "base_normal")], - indirect=["create_cmabs", "create_cmab_payload"], - ) - def test_get_all_cmabs( - self, - client: TestClient, - admin_token: str, - n_expected: int, - create_cmab_payload: dict, - create_cmabs: list, - ) -> None: - response = client.get( - "/contextual_mab", headers={"Authorization": f"Bearer {admin_token}"} - ) - assert response.status_code == 200 - assert len(response.json()) == 
n_expected - - @mark.parametrize( - "create_cmabs, expected_response, create_cmab_payload", - [(0, 404, "base_normal"), (2, 200, "base_normal")], - indirect=["create_cmabs", "create_cmab_payload"], - ) - def test_get_cmab( - self, - client: TestClient, - admin_token: str, - create_cmab_payload: dict, - create_cmabs: list, - expected_response: int, - ) -> None: - id = create_cmabs[0]["experiment_id"] if create_cmabs else 999 - - response = client.get( - f"/contextual_mab/{id}", headers={"Authorization": f"Bearer {admin_token}"} - ) - assert response.status_code == expected_response - - @mark.parametrize("create_cmab_payload", ["base_normal"], indirect=True) - def test_draw_arm_draw_id_provided( - self, - client: TestClient, - create_cmabs: list, - create_cmab_payload: dict, - workspace_api_key: str, - ) -> None: - id = create_cmabs[0]["experiment_id"] - response = client.post( - f"/contextual_mab/{id}/draw", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - params={"draw_id": "test_draw_id"}, - json=[ - {"context_id": 1, "context_value": 0}, - {"context_id": 2, "context_value": 0.5}, - ], - ) - assert response.status_code == 200 - assert response.json()["draw_id"] == "test_draw_id" - - @mark.parametrize("create_cmab_payload", ["base_normal"], indirect=True) - def test_draw_arm_no_draw_id_provided( - self, - client: TestClient, - create_cmabs: list, - create_cmab_payload: dict, - workspace_api_key: str, - ) -> None: - id = create_cmabs[0]["experiment_id"] - response = client.post( - f"/contextual_mab/{id}/draw", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - json=[ - {"context_id": 1, "context_value": 0}, - {"context_id": 2, "context_value": 0.5}, - ], - ) - assert response.status_code == 200 - assert len(response.json()["draw_id"]) == 36 - - @mark.parametrize( - "create_cmab_payload, client_id, expected_response", - [ - ("with_sticky_assignment", None, 400), - ("with_sticky_assignment", "test_client_id", 200), - ], - 
indirect=["create_cmab_payload"], - ) - def test_draw_arm_sticky_assignment_client_id_provided( - self, - client: TestClient, - create_cmabs: list, - create_cmab_payload: dict, - client_id: str | None, - expected_response: int, - workspace_api_key: str, - ) -> None: - id = create_cmabs[0]["experiment_id"] - url = f"/contextual_mab/{id}/draw" - if client_id: - url += f"?client_id={client_id}" - - response = client.post( - url, - headers={"Authorization": f"Bearer {workspace_api_key}"}, - json=[ - {"context_id": 1, "context_value": 0}, - {"context_id": 2, "context_value": 0.5}, - ], - ) - assert response.status_code == expected_response - - @mark.parametrize("create_cmab_payload", ["with_sticky_assignment"], indirect=True) - def test_draw_arm_with_sticky_assignment( - self, - client: TestClient, - create_cmabs: list, - create_cmab_payload: dict, - workspace_api_key: str, - ) -> None: - id = create_cmabs[0]["experiment_id"] - arm_ids = [] - - for _ in range(10): - response = client.post( - f"/contextual_mab/{id}/draw?client_id=123", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - json=[ - {"context_id": 1, "context_value": 0}, - {"context_id": 2, "context_value": 1}, - ], - ) - arm_ids.append(response.json()["arm"]["arm_id"]) - - assert np.unique(arm_ids).size == 1 - - @mark.parametrize("create_cmab_payload", ["base_normal"], indirect=True) - def test_one_outcome_per_draw( - self, - client: TestClient, - create_cmabs: list, - create_cmab_payload: dict, - workspace_api_key: str, - ) -> None: - id = create_cmabs[0]["experiment_id"] - response = client.post( - f"/contextual_mab/{id}/draw", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - json=[ - {"context_id": 1, "context_value": 0}, - {"context_id": 2, "context_value": 0.5}, - ], - ) - assert response.status_code == 200 - draw_id = response.json()["draw_id"] - - response = client.put( - f"/contextual_mab/{id}/{draw_id}/1", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - - 
assert response.status_code == 200 - - response = client.put( - f"/contextual_mab/{id}/{draw_id}/1", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - - assert response.status_code == 400 - - @mark.parametrize( - "n_draws, create_cmab_payload", - [(0, "base_normal"), (1, "base_normal"), (5, "base_normal")], - indirect=["create_cmab_payload"], - ) - def test_get_outcomes( - self, - client: TestClient, - create_cmabs: list, - n_draws: int, - create_cmab_payload: dict, - workspace_api_key: str, - ) -> None: - id = create_cmabs[0]["experiment_id"] - - for _ in range(n_draws): - response = client.post( - f"/contextual_mab/{id}/draw", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - json=[ - {"context_id": 1, "context_value": 0}, - {"context_id": 2, "context_value": 0.5}, - ], - ) - assert response.status_code == 200 - draw_id = response.json()["draw_id"] - response = client.put( - f"/contextual_mab/{id}/{draw_id}/1", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - - response = client.get( - f"/contextual_mab/{id}/outcomes", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - - assert response.status_code == 200 - assert len(response.json()) == n_draws - - -class TestNotifications: - @fixture() - def create_cmab_payload(self, request: FixtureRequest) -> dict: - payload: dict = copy.deepcopy(base_normal_payload) - payload["arms"] = list(payload["arms"]) - payload["contexts"] = list(payload["contexts"]) - - match request.param: - case "base": - pass - case "daysElapsed_only": - payload["notifications"]["onTrialCompletion"] = False - payload["notifications"]["onDaysElapsed"] = True - case "trialCompletion_only": - payload["notifications"]["onTrialCompletion"] = True - case "percentBetter_only": - payload["notifications"]["onTrialCompletion"] = False - payload["notifications"]["onPercentBetter"] = True - case "all_notifications": - payload["notifications"]["onDaysElapsed"] = True - 
payload["notifications"]["onPercentBetter"] = True - case "no_notifications": - payload["notifications"]["onTrialCompletion"] = False - case "daysElapsed_missing": - payload["notifications"]["daysElapsed"] = 0 - payload["notifications"]["onDaysElapsed"] = True - case "trialCompletion_missing": - payload["notifications"]["numberOfTrials"] = 0 - payload["notifications"]["onTrialCompletion"] = True - case "percentBetter_missing": - payload["notifications"]["percentBetterThreshold"] = 0 - payload["notifications"]["onPercentBetter"] = True - case _: - raise ValueError("Invalid parameter") - - return payload - - @mark.parametrize( - "create_cmab_payload, expected_response", - [ - ("base", 200), - ("daysElapsed_only", 200), - ("trialCompletion_only", 200), - ("percentBetter_only", 200), - ("all_notifications", 200), - ("no_notifications", 200), - ("daysElapsed_missing", 422), - ("trialCompletion_missing", 422), - ("percentBetter_missing", 422), - ], - indirect=["create_cmab_payload"], - ) - def test_notifications( - self, - client: TestClient, - admin_token: str, - create_cmab_payload: dict, - expected_response: int, - clean_cmabs: None, - ) -> None: - response = client.post( - "/contextual_mab", - json=create_cmab_payload, - headers={"Authorization": f"Bearer {admin_token}"}, - ) - - assert response.status_code == expected_response diff --git a/backend/tests/test_mabs.py b/backend/tests/test_mabs.py deleted file mode 100644 index e65ccb6..0000000 --- a/backend/tests/test_mabs.py +++ /dev/null @@ -1,463 +0,0 @@ -import copy -import os -from typing import Generator - -import numpy as np -from fastapi.testclient import TestClient -from pytest import FixtureRequest, fixture, mark -from sqlalchemy.orm import Session - -from backend.app.mab.models import MABArmDB, MultiArmedBanditDB -from backend.app.models import NotificationsDB - -base_beta_binom_payload = { - "name": "Test", - "description": "Test description", - "prior_type": "beta", - "reward_type": "binary", - 
"sticky_assignment": False, - "arms": [ - { - "name": "arm 1", - "description": "arm 1 description", - "alpha_init": 5, - "beta_init": 1, - }, - { - "name": "arm 2", - "description": "arm 2 description", - "alpha_init": 1, - "beta_init": 4, - }, - ], - "notifications": { - "onTrialCompletion": True, - "numberOfTrials": 2, - "onDaysElapsed": False, - "daysElapsed": 3, - "onPercentBetter": False, - "percentBetterThreshold": 5, - }, -} - -base_normal_payload = base_beta_binom_payload.copy() -base_normal_payload["prior_type"] = "normal" -base_normal_payload["reward_type"] = "real-valued" -base_normal_payload["arms"] = [ - { - "name": "arm 1", - "description": "arm 1 description", - "mu_init": 2, - "sigma_init": 3, - }, - { - "name": "arm 2", - "description": "arm 2 description", - "mu_init": 3, - "sigma_init": 7, - }, -] - - -@fixture -def admin_token(client: TestClient) -> str: - response = client.post( - "/login", - data={ - "username": os.environ.get("ADMIN_USERNAME", ""), - "password": os.environ.get("ADMIN_PASSWORD", ""), - }, - ) - token = response.json()["access_token"] - return token - - -@fixture -def clean_mabs(db_session: Session) -> Generator: - yield - db_session.query(NotificationsDB).delete() - db_session.query(MABArmDB).delete() - db_session.query(MultiArmedBanditDB).delete() - db_session.commit() - - -class TestMab: - @fixture - def create_mab_payload(self, request: FixtureRequest) -> dict: - payload_beta_binom: dict = copy.deepcopy(base_beta_binom_payload) - payload_beta_binom["arms"] = list(payload_beta_binom["arms"]) - - payload_normal: dict = copy.deepcopy(base_normal_payload) - payload_normal["arms"] = list(payload_normal["arms"]) - - if request.param == "base_beta_binom": - return payload_beta_binom - if request.param == "base_normal": - return payload_normal - if request.param == "one_arm": - payload_beta_binom["arms"].pop() - return payload_beta_binom - if request.param == "no_notifications": - 
payload_beta_binom["notifications"]["onTrialCompletion"] = False - return payload_beta_binom - if request.param == "invalid_prior": - payload_beta_binom["prior_type"] = "invalid" - return payload_beta_binom - if request.param == "invalid_reward": - payload_beta_binom["reward_type"] = "invalid" - return payload_beta_binom - if request.param == "invalid_alpha": - payload_beta_binom["arms"][0]["alpha_init"] = -1 - return payload_beta_binom - if request.param == "invalid_beta": - payload_beta_binom["arms"][0]["beta_init"] = -1 - return payload_beta_binom - if request.param == "invalid_combo_1": - payload_beta_binom["prior_type"] = "normal" - return payload_beta_binom - if request.param == "invalid_combo_2": - payload_beta_binom["reward_type"] = "continuous" - return payload_beta_binom - if request.param == "incorrect_params": - payload_beta_binom["arms"][0].pop("alpha_init") - return payload_beta_binom - if request.param == "invalid_sigma": - payload_normal["arms"][0]["sigma_init"] = 0.0 - return payload_normal - if request.param == "with_sticky_assignment": - payload_beta_binom["sticky_assignment"] = True - return payload_beta_binom - else: - raise ValueError("Invalid parameter") - - @mark.parametrize( - "create_mab_payload, expected_response", - [ - ("base_beta_binom", 200), - ("base_normal", 200), - ("one_arm", 422), - ("no_notifications", 200), - ("invalid_prior", 422), - ("invalid_reward", 422), - ("invalid_alpha", 422), - ("invalid_beta", 422), - ("invalid_combo_1", 422), - ("invalid_combo_2", 422), - ("incorrect_params", 422), - ("invalid_sigma", 422), - ], - indirect=["create_mab_payload"], - ) - def test_create_mab( - self, - create_mab_payload: dict, - client: TestClient, - expected_response: int, - admin_token: str, - clean_mabs: None, - ) -> None: - response = client.post( - "/mab", - json=create_mab_payload, - headers={"Authorization": f"Bearer {admin_token}"}, - ) - - assert response.status_code == expected_response - - @fixture - def create_mabs( - self, 
- client: TestClient, - admin_token: str, - request: FixtureRequest, - create_mab_payload: dict, - ) -> Generator: - mabs = [] - n_mabs = request.param if hasattr(request, "param") else 1 - for _ in range(n_mabs): - response = client.post( - "/mab", - json=create_mab_payload, - headers={"Authorization": f"Bearer {admin_token}"}, - ) - mabs.append(response.json()) - yield mabs - for mab in mabs: - client.delete( - f"/mab/{mab['experiment_id']}", - headers={"Authorization": f"Bearer {admin_token}"}, - ) - - @mark.parametrize( - "create_mabs, create_mab_payload, n_expected", - [ - (0, "base_beta_binom", 0), - (2, "base_beta_binom", 2), - (5, "base_beta_binom", 5), - ], - indirect=["create_mabs", "create_mab_payload"], - ) - def test_get_all_mabs( - self, - client: TestClient, - admin_token: str, - n_expected: int, - create_mabs: list, - create_mab_payload: dict, - ) -> None: - response = client.get( - "/mab", headers={"Authorization": f"Bearer {admin_token}"} - ) - assert response.status_code == 200 - assert len(response.json()) == n_expected - - @mark.parametrize( - "create_mabs, create_mab_payload, expected_response", - [(0, "base_beta_binom", 404), (2, "base_beta_binom", 200)], - indirect=["create_mabs", "create_mab_payload"], - ) - def test_get_mab( - self, - client: TestClient, - admin_token: str, - create_mabs: list, - create_mab_payload: dict, - expected_response: int, - ) -> None: - id = create_mabs[0]["experiment_id"] if create_mabs else 999 - - response = client.get( - f"/mab/{id}/", headers={"Authorization": f"Bearer {admin_token}"} - ) - assert response.status_code == expected_response - - @mark.parametrize("create_mab_payload", ["base_beta_binom"], indirect=True) - def test_draw_arm_draw_id_provided( - self, - client: TestClient, - create_mabs: list, - create_mab_payload: dict, - workspace_api_key: str, - ) -> None: - id = create_mabs[0]["experiment_id"] - response = client.get( - f"/mab/{id}/draw", - params={"draw_id": "test_draw"}, - 
headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - assert response.status_code == 200 - assert response.json()["draw_id"] == "test_draw" - - @mark.parametrize("create_mab_payload", ["base_beta_binom"], indirect=True) - def test_draw_arm_no_draw_id_provided( - self, - client: TestClient, - create_mabs: list, - create_mab_payload: dict, - workspace_api_key: str, - ) -> None: - id = create_mabs[0]["experiment_id"] - response = client.get( - f"/mab/{id}/draw", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - assert response.status_code == 200 - assert len(response.json()["draw_id"]) == 36 - - @mark.parametrize( - "create_mab_payload, client_id, expected_response", - [ - ("with_sticky_assignment", None, 400), - ("with_sticky_assignment", "test_client_id", 200), - ], - indirect=["create_mab_payload"], - ) - def test_draw_arm_sticky_assignment_with_client_id( - self, - client: TestClient, - admin_token: str, - create_mab_payload: dict, - create_mabs: list, - client_id: str | None, - expected_response: int, - workspace_api_key: str, - ) -> None: - mabs = create_mabs - id = mabs[0]["experiment_id"] - response = client.get( - f"/mab/{id}/draw{'?client_id=' + client_id if client_id else ''}", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - assert response.status_code == expected_response - - @mark.parametrize("create_mab_payload", ["with_sticky_assignment"], indirect=True) - def test_draw_arm_sticky_assignment_client_id_provided( - self, - client: TestClient, - admin_token: str, - create_mab_payload: dict, - create_mabs: list, - workspace_api_key: str, - ) -> None: - mabs = create_mabs - id = mabs[0]["experiment_id"] - response = client.get( - f"/mab/{id}/draw?client_id=123", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - assert response.status_code == 200 - - @mark.parametrize("create_mab_payload", ["with_sticky_assignment"], indirect=True) - def test_draw_arm_sticky_assignment_similar_arms( - self, - client: 
TestClient, - admin_token: str, - create_mab_payload: dict, - create_mabs: list, - workspace_api_key: str, - ) -> None: - mabs = create_mabs - id = mabs[0]["experiment_id"] - - arm_ids = [] - for _ in range(10): - response = client.get( - f"/mab/{id}/draw?client_id=123", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - arm_ids.append(response.json()["arm"]["arm_id"]) - assert np.unique(arm_ids).size == 1 - - @mark.parametrize("create_mab_payload", ["base_beta_binom"], indirect=True) - def test_one_outcome_per_draw( - self, - client: TestClient, - create_mabs: list, - create_mab_payload: dict, - workspace_api_key: str, - ) -> None: - id = create_mabs[0]["experiment_id"] - response = client.get( - f"/mab/{id}/draw", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - assert response.status_code == 200 - draw_id = response.json()["draw_id"] - - response = client.put( - f"/mab/{id}/{draw_id}/1", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - - assert response.status_code == 200 - - response = client.put( - f"/mab/{id}/{draw_id}/1", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - - assert response.status_code == 400 - - @mark.parametrize( - "n_draws, create_mab_payload", - [(0, "base_beta_binom"), (1, "base_beta_binom"), (5, "base_beta_binom")], - indirect=["create_mab_payload"], - ) - def test_get_outcomes( - self, - client: TestClient, - create_mabs: list, - n_draws: int, - create_mab_payload: dict, - workspace_api_key: str, - ) -> None: - id = create_mabs[0]["experiment_id"] - - for _ in range(n_draws): - response = client.get( - f"/mab/{id}/draw", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - assert response.status_code == 200 - draw_id = response.json()["draw_id"] - # put outcomes - response = client.put( - f"/mab/{id}/{draw_id}/1", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - - response = client.get( - f"/mab/{id}/outcomes", - headers={"Authorization": 
f"Bearer {workspace_api_key}"}, - ) - - assert response.status_code == 200 - assert len(response.json()) == n_draws - - -class TestNotifications: - @fixture() - def create_mab_payload(self, request: FixtureRequest) -> dict: - payload: dict = copy.deepcopy(base_beta_binom_payload) - payload["arms"] = list(payload["arms"]) - - match request.param: - case "base": - pass - case "daysElapsed_only": - payload["notifications"]["onTrialCompletion"] = False - payload["notifications"]["onDaysElapsed"] = True - case "trialCompletion_only": - payload["notifications"]["onTrialCompletion"] = True - case "percentBetter_only": - payload["notifications"]["onTrialCompletion"] = False - payload["notifications"]["onPercentBetter"] = True - case "all_notifications": - payload["notifications"]["onDaysElapsed"] = True - payload["notifications"]["onPercentBetter"] = True - case "no_notifications": - payload["notifications"]["onTrialCompletion"] = False - case "daysElapsed_missing": - payload["notifications"]["daysElapsed"] = 0 - payload["notifications"]["onDaysElapsed"] = True - case "trialCompletion_missing": - payload["notifications"]["numberOfTrials"] = 0 - payload["notifications"]["onTrialCompletion"] = True - case "percentBetter_missing": - payload["notifications"]["percentBetterThreshold"] = 0 - payload["notifications"]["onPercentBetter"] = True - case _: - raise ValueError("Invalid parameter") - - return payload - - @mark.parametrize( - "create_mab_payload, expected_response", - [ - ("base", 200), - ("daysElapsed_only", 200), - ("trialCompletion_only", 200), - ("percentBetter_only", 200), - ("all_notifications", 200), - ("no_notifications", 200), - ("daysElapsed_missing", 422), - ("trialCompletion_missing", 422), - ("percentBetter_missing", 422), - ], - indirect=["create_mab_payload"], - ) - def test_notifications( - self, - client: TestClient, - admin_token: str, - create_mab_payload: dict, - expected_response: int, - clean_mabs: None, - ) -> None: - response = client.post( - 
"/mab", - json=create_mab_payload, - headers={"Authorization": f"Bearer {admin_token}"}, - ) - - assert response.status_code == expected_response From 4845a449ff14c436ecc8367198b0471611ea7701 Mon Sep 17 00:00:00 2001 From: poornimaramesh Date: Fri, 6 Jun 2025 17:04:52 +0300 Subject: [PATCH 59/74] merge changes from refactor --- backend/add_users_to_db.py | 102 ++- backend/app/__init__.py | 7 +- backend/app/bayes_ab/__init__.py | 1 - backend/app/bayes_ab/models.py | 433 ----------- backend/app/bayes_ab/observation.py | 75 -- backend/app/bayes_ab/routers.py | 524 ------------- backend/app/bayes_ab/sampling_utils.py | 126 --- backend/app/bayes_ab/schemas.py | 145 ---- backend/app/contextual_mab/__init__.py | 1 - backend/app/contextual_mab/models.py | 483 ------------ backend/app/contextual_mab/observation.py | 126 --- backend/app/contextual_mab/routers.py | 395 ---------- backend/app/contextual_mab/sampling_utils.py | 172 ---- backend/app/contextual_mab/schemas.py | 268 ------- backend/app/experiments/dependencies.py | 269 +++++++ backend/app/experiments/models.py | 736 ++++++++++++++++++ backend/app/experiments/routers.py | 489 ++++++++++++ backend/app/experiments/sampling_utils.py | 326 ++++++++ backend/app/experiments/schemas.py | 548 +++++++++++++ backend/app/mab/__init__.py | 1 - backend/app/mab/models.py | 419 ---------- backend/app/mab/observation.py | 94 --- backend/app/mab/routers.py | 357 --------- backend/app/mab/sampling_utils.py | 138 ---- backend/app/mab/schemas.py | 262 ------- backend/app/models.py | 9 +- backend/app/workspaces/models.py | 5 +- backend/jobs/auto_fail.py | 188 +---- .../275ff74c0866_add_client_id_to_draws_db.py | 30 - ..._add_tables_for_bayesian_ab_experiments.py | 66 -- ...added_first_name_and_last_name_to_users.py | 36 - .../versions/6101ba814d91_fresh_start.py | 438 +++++++++++ .../versions/9f7482ba882f_workspace_model.py | 123 --- .../ecddd830b464_remove_user_api_key.py | 70 -- .../versions/faf4228e13a3_clean_start.py | 257 ------ 
...d_added_sticky_assignments_and_autofail.py | 59 -- backend/tests/pytest.ini | 5 + .../docker-compose/docker-compose-dev.yml | 12 +- deployment/docker-compose/docker-compose.yml | 2 +- .../src/app/(protected)/workspaces/page.tsx | 18 +- frontend/src/components/app-sidebar.tsx | 80 +- frontend/src/components/ui/tabs.tsx | 2 +- .../src/components/workspace-switcher.tsx | 1 + frontend/src/utils/auth.tsx | 20 +- 44 files changed, 2995 insertions(+), 4923 deletions(-) delete mode 100644 backend/app/bayes_ab/__init__.py delete mode 100644 backend/app/bayes_ab/models.py delete mode 100644 backend/app/bayes_ab/observation.py delete mode 100644 backend/app/bayes_ab/routers.py delete mode 100644 backend/app/bayes_ab/sampling_utils.py delete mode 100644 backend/app/bayes_ab/schemas.py delete mode 100644 backend/app/contextual_mab/__init__.py delete mode 100644 backend/app/contextual_mab/models.py delete mode 100644 backend/app/contextual_mab/observation.py delete mode 100644 backend/app/contextual_mab/routers.py delete mode 100644 backend/app/contextual_mab/sampling_utils.py delete mode 100644 backend/app/contextual_mab/schemas.py create mode 100644 backend/app/experiments/dependencies.py create mode 100644 backend/app/experiments/models.py create mode 100644 backend/app/experiments/routers.py create mode 100644 backend/app/experiments/sampling_utils.py create mode 100644 backend/app/experiments/schemas.py delete mode 100644 backend/app/mab/__init__.py delete mode 100644 backend/app/mab/models.py delete mode 100644 backend/app/mab/observation.py delete mode 100644 backend/app/mab/routers.py delete mode 100644 backend/app/mab/sampling_utils.py delete mode 100644 backend/app/mab/schemas.py delete mode 100644 backend/migrations/versions/275ff74c0866_add_client_id_to_draws_db.py delete mode 100644 backend/migrations/versions/28adf347e68d_add_tables_for_bayesian_ab_experiments.py delete mode 100644 
backend/migrations/versions/5c15463fda65_added_first_name_and_last_name_to_users.py create mode 100644 backend/migrations/versions/6101ba814d91_fresh_start.py delete mode 100644 backend/migrations/versions/9f7482ba882f_workspace_model.py delete mode 100644 backend/migrations/versions/ecddd830b464_remove_user_api_key.py delete mode 100644 backend/migrations/versions/faf4228e13a3_clean_start.py delete mode 100644 backend/migrations/versions/feb042798cad_added_sticky_assignments_and_autofail.py create mode 100644 backend/tests/pytest.ini diff --git a/backend/add_users_to_db.py b/backend/add_users_to_db.py index 1efbe80..82235e3 100644 --- a/backend/add_users_to_db.py +++ b/backend/add_users_to_db.py @@ -1,8 +1,11 @@ import asyncio import os from datetime import datetime, timezone +from typing import Union from redis import asyncio as aioredis +from sqlalchemy import select +from sqlalchemy.orm import Session from app.config import REDIS_HOST from app.database import get_session @@ -33,7 +36,7 @@ async def async_redis_operations(key: str, value: int | None) -> None: await redis.aclose() -def run_redis_async_tasks(key: str, value: int | str) -> None: +def run_redis_async_tasks(key: str, value: Union[int, str]) -> None: """ Run asynchronous Redis operations to set the remaining API calls for a user. """ @@ -43,6 +46,103 @@ def run_redis_async_tasks(key: str, value: int | str) -> None: loop.run_until_complete(async_redis_operations(key, value_int)) +def ensure_default_workspace(db_session: Session, user_db: UserDB) -> None: + """ + Ensure that a user has a default workspace. + + Parameters + ---------- + db_session + The database session. + user_db + The user DB record. 
+ """ + # Check if user already has a workspace + stmt = select(UserWorkspaceDB).where(UserWorkspaceDB.user_id == user_db.user_id) + result = db_session.execute(stmt) + existing_workspace = result.scalar_one_or_none() + + if existing_workspace: + logger.info( + f"User {user_db.username} already has workspace relationship: " + f"{existing_workspace.workspace_id}" + ) + # Check if any workspace is set as default + stmt = select(UserWorkspaceDB).where( + UserWorkspaceDB.user_id == user_db.user_id, + UserWorkspaceDB.default_workspace, + ) + result = db_session.execute(stmt) + default_workspace = result.scalar_one_or_none() + + if default_workspace: + logger.info( + f"User {user_db.username} already has default workspace: " + f"{default_workspace.workspace_id}" + ) + return + else: + # Set first workspace as default + existing_workspace.default_workspace = True + db_session.add(existing_workspace) + db_session.commit() + logger.info( + f"Set workspace {existing_workspace.workspace_id} as default for " + f"{user_db.username}" + ) + return + + # Create a default workspace for the user + workspace_name = f"{user_db.username}'s Workspace" + + # Check if workspace with this name already exists + stmt = select(WorkspaceDB).where(WorkspaceDB.workspace_name == workspace_name) + result = db_session.execute(stmt) + existing_workspace_db = result.scalar_one_or_none() + + if existing_workspace_db: + workspace_db = existing_workspace_db + logger.info( + f"Workspace '{workspace_name}' already exists with ID " + f"{workspace_db.workspace_id}" + ) + else: + # Create new workspace + workspace_db = WorkspaceDB( + workspace_name=workspace_name, + api_daily_quota=100, + content_quota=10, + created_datetime_utc=datetime.now(timezone.utc), + updated_datetime_utc=datetime.now(timezone.utc), + is_default=True, + hashed_api_key=get_key_hash("workspace-api-key-" + workspace_name), + api_key_first_characters="works", + api_key_updated_datetime_utc=datetime.now(timezone.utc), + 
api_key_rotated_by_user_id=user_db.user_id, + ) + db_session.add(workspace_db) + db_session.commit() + logger.info( + f"Created workspace '{workspace_name}' with ID {workspace_db.workspace_id}" + ) + + # Create user-workspace relationship + user_workspace = UserWorkspaceDB( + user_id=user_db.user_id, + workspace_id=workspace_db.workspace_id, + user_role=UserRoles.ADMIN, + default_workspace=True, + created_datetime_utc=datetime.now(timezone.utc), + updated_datetime_utc=datetime.now(timezone.utc), + ) + db_session.add(user_workspace) + db_session.commit() + logger.info( + f"Created workspace relationship for user {user_db.username} with workspace " + f"{workspace_db.workspace_id}" + ) + + if __name__ == "__main__": db_session = next(get_session()) diff --git a/backend/app/__init__.py b/backend/app/__init__.py index 37790a8..88e9dfd 100644 --- a/backend/app/__init__.py +++ b/backend/app/__init__.py @@ -5,8 +5,9 @@ from fastapi.middleware.cors import CORSMiddleware from redis import asyncio as aioredis -from . import auth, bayes_ab, contextual_mab, mab, messages +from . 
import auth, messages from .config import BACKEND_ROOT_PATH, DOMAIN, REDIS_HOST +from .experiments.routers import router as experiments_router from .users.routers import ( router as users_router, ) # to avoid circular imports @@ -56,9 +57,7 @@ def create_app() -> FastAPI: expose_headers=["*"], ) - app.include_router(mab.router) - app.include_router(contextual_mab.router) - app.include_router(bayes_ab.router) + app.include_router(experiments_router) app.include_router(auth.router) app.include_router(users_router) app.include_router(messages.router) diff --git a/backend/app/bayes_ab/__init__.py b/backend/app/bayes_ab/__init__.py deleted file mode 100644 index fa07d07..0000000 --- a/backend/app/bayes_ab/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .routers import router # noqa: F401 diff --git a/backend/app/bayes_ab/models.py b/backend/app/bayes_ab/models.py deleted file mode 100644 index 8caee04..0000000 --- a/backend/app/bayes_ab/models.py +++ /dev/null @@ -1,433 +0,0 @@ -from datetime import datetime, timezone -from typing import Sequence - -from sqlalchemy import ( - Boolean, - Float, - ForeignKey, - and_, - delete, - select, -) -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import Mapped, mapped_column, relationship - -from ..models import ( - ArmBaseDB, - DrawsBaseDB, - ExperimentBaseDB, - NotificationsDB, -) -from ..schemas import ObservationType -from .schemas import BayesianAB - - -class BayesianABDB(ExperimentBaseDB): - """ - ORM for managing experiments. 
- """ - - __tablename__ = "bayes_ab_experiments" - - experiment_id: Mapped[int] = mapped_column( - ForeignKey("experiments_base.experiment_id", ondelete="CASCADE"), - primary_key=True, - nullable=False, - ) - - arms: Mapped[list["BayesianABArmDB"]] = relationship( - "BayesianABArmDB", back_populates="experiment", lazy="selectin" - ) - - draws: Mapped[list["BayesianABDrawDB"]] = relationship( - "BayesianABDrawDB", back_populates="experiment", lazy="joined" - ) - - __mapper_args__ = {"polymorphic_identity": "bayes_ab_experiments"} - - def to_dict(self) -> dict: - """ - Convert the ORM object to a dictionary. - """ - return { - "experiment_id": self.experiment_id, - "user_id": self.user_id, - "workspace_id": self.workspace_id, - "name": self.name, - "description": self.description, - "sticky_assignment": self.sticky_assignment, - "auto_fail": self.auto_fail, - "auto_fail_value": self.auto_fail_value, - "auto_fail_unit": self.auto_fail_unit, - "created_datetime_utc": self.created_datetime_utc, - "is_active": self.is_active, - "n_trials": self.n_trials, - "arms": [arm.to_dict() for arm in self.arms], - "prior_type": self.prior_type, - "reward_type": self.reward_type, - } - - -class BayesianABArmDB(ArmBaseDB): - """ - ORM for managing arms. 
- """ - - __tablename__ = "bayes_ab_arms" - - arm_id: Mapped[int] = mapped_column( - ForeignKey("arms_base.arm_id", ondelete="CASCADE"), - primary_key=True, - nullable=False, - ) - - # prior variables for AB arms - mu_init: Mapped[float] = mapped_column(Float, nullable=False) - sigma_init: Mapped[float] = mapped_column(Float, nullable=False) - mu: Mapped[float] = mapped_column(Float, nullable=False) - sigma: Mapped[float] = mapped_column(Float, nullable=False) - is_treatment_arm: Mapped[bool] = mapped_column( - Boolean, nullable=False, default=False - ) - - experiment: Mapped[BayesianABDB] = relationship( - "BayesianABDB", back_populates="arms", lazy="joined" - ) - draws: Mapped[list["BayesianABDrawDB"]] = relationship( - "BayesianABDrawDB", back_populates="arm", lazy="joined" - ) - - __mapper_args__ = {"polymorphic_identity": "bayes_ab_arms"} - - def to_dict(self) -> dict: - """ - Convert the ORM object to a dictionary. - """ - return { - "arm_id": self.arm_id, - "name": self.name, - "description": self.description, - "mu_init": self.mu_init, - "sigma_init": self.sigma_init, - "mu": self.mu, - "sigma": self.sigma, - "is_treatment_arm": self.is_treatment_arm, - "draws": [draw.to_dict() for draw in self.draws], - } - - -class BayesianABDrawDB(DrawsBaseDB): - """ - ORM for managing draws of AB experiment. - """ - - __tablename__ = "bayes_ab_draws" - - draw_id: Mapped[str] = mapped_column( # Changed from int to str - ForeignKey("draws_base.draw_id", ondelete="CASCADE"), - primary_key=True, - nullable=False, - ) - - arm: Mapped[BayesianABArmDB] = relationship( - "BayesianABArmDB", back_populates="draws", lazy="joined" - ) - experiment: Mapped[BayesianABDB] = relationship( - "BayesianABDB", back_populates="draws", lazy="joined" - ) - - __mapper_args__ = {"polymorphic_identity": "bayes_ab_draws"} - - def to_dict(self) -> dict: - """ - Convert the ORM object to a dictionary. 
- """ - return { - "draw_id": self.draw_id, - "client_id": self.client_id, - "draw_datetime_utc": self.draw_datetime_utc, - "arm_id": self.arm_id, - "experiment_id": self.experiment_id, - "user_id": self.user_id, - "reward": self.reward, - "observation_type": self.observation_type, - "observed_datetime_utc": self.observed_datetime_utc, - } - - -async def save_bayes_ab_to_db( - ab_experiment: BayesianAB, - user_id: int, - workspace_id: int, - asession: AsyncSession, -) -> BayesianABDB: - """ - Save the A/B experiment to the database. - """ - arms = [ - BayesianABArmDB( - name=arm.name, - description=arm.description, - mu_init=arm.mu_init, - sigma_init=arm.sigma_init, - n_outcomes=arm.n_outcomes, - is_treatment_arm=arm.is_treatment_arm, - mu=arm.mu_init, - sigma=arm.sigma_init, - user_id=user_id, - ) - for arm in ab_experiment.arms - ] - - bayes_ab_db = BayesianABDB( - name=ab_experiment.name, - description=ab_experiment.description, - user_id=user_id, - workspace_id=workspace_id, - is_active=ab_experiment.is_active, - created_datetime_utc=datetime.now(timezone.utc), - n_trials=0, - arms=arms, - sticky_assignment=ab_experiment.sticky_assignment, - auto_fail=ab_experiment.auto_fail, - auto_fail_value=ab_experiment.auto_fail_value, - auto_fail_unit=ab_experiment.auto_fail_unit, - prior_type=ab_experiment.prior_type.value, - reward_type=ab_experiment.reward_type.value, - ) - - asession.add(bayes_ab_db) - await asession.commit() - await asession.refresh(bayes_ab_db) - - return bayes_ab_db - - -async def get_all_bayes_ab_experiments( - workspace_id: int, - asession: AsyncSession, -) -> Sequence[BayesianABDB]: - """ - Get all the A/B experiments from the database for a specific workspace. 
- """ - stmt = ( - select(BayesianABDB) - .where(BayesianABDB.workspace_id == workspace_id) - .order_by(BayesianABDB.experiment_id) - ) - result = await asession.execute(stmt) - return result.unique().scalars().all() - - -async def get_bayes_ab_experiment_by_id( - experiment_id: int, - workspace_id: int, - asession: AsyncSession, -) -> BayesianABDB | None: - """ - Get the A/B experiment by id from a specific workspace. - """ - conditions = [ - BayesianABDB.workspace_id == workspace_id, - BayesianABDB.experiment_id == experiment_id, - ] - - stmt = select(BayesianABDB).where(and_(*conditions)) - result = await asession.execute(stmt) - return result.unique().scalar_one_or_none() - - -async def delete_bayes_ab_experiment_by_id( - experiment_id: int, - workspace_id: int, - asession: AsyncSession, -) -> None: - """ - Delete the A/B experiment by id from a specific workspace. - """ - stmt = delete(BayesianABDB).where( - and_( - BayesianABDB.workspace_id == workspace_id, - BayesianABDB.experiment_id == experiment_id, - BayesianABDB.experiment_id == ExperimentBaseDB.experiment_id, - ) - ) - await asession.execute(stmt) - - stmt = delete(NotificationsDB).where( - NotificationsDB.experiment_id == experiment_id, - ) - await asession.execute(stmt) - - stmt = delete(BayesianABDrawDB).where( - and_( - BayesianABDrawDB.draw_id == DrawsBaseDB.draw_id, - BayesianABDrawDB.experiment_id == experiment_id, - ) - ) - await asession.execute(stmt) - - stmt = delete(BayesianABArmDB).where( - and_( - BayesianABArmDB.arm_id == ArmBaseDB.arm_id, - BayesianABArmDB.experiment_id == experiment_id, - ) - ) - await asession.execute(stmt) - - await asession.commit() - return None - - -async def save_bayes_ab_observation_to_db( - draw: BayesianABDrawDB, - reward: float, - asession: AsyncSession, - observation_type: ObservationType = ObservationType.AUTO, -) -> BayesianABDrawDB: - """ - Save the A/B observation to the database. 
- """ - draw.reward = reward - draw.observed_datetime_utc = datetime.now(timezone.utc) - draw.observation_type = observation_type - - await asession.commit() - await asession.refresh(draw) - - return draw - - -async def save_bayes_ab_draw_to_db( - experiment_id: int, - arm_id: int, - draw_id: str, - client_id: str | None, - user_id: int | None, - asession: AsyncSession, - workspace_id: int | None, -) -> BayesianABDrawDB: - """ - Save a draw to the database - """ - # If user_id is not provided but needed, get it from the experiment - if user_id is None and workspace_id is not None: - experiment = await get_bayes_ab_experiment_by_id( - experiment_id=experiment_id, - workspace_id=workspace_id, - asession=asession, - ) - if experiment: - user_id = experiment.user_id - else: - raise ValueError(f"Experiment with id {experiment_id} not found") - - if user_id is None: - raise ValueError("User ID must be provided or derivable from experiment") - - draw_datetime_utc: datetime = datetime.now(timezone.utc) - - draw = BayesianABDrawDB( - draw_id=draw_id, - client_id=client_id, - experiment_id=experiment_id, - user_id=user_id, - arm_id=arm_id, - draw_datetime_utc=draw_datetime_utc, - ) - - asession.add(draw) - await asession.commit() - await asession.refresh(draw) - - return draw - - -async def get_bayes_ab_obs_by_experiment_arm_id( - experiment_id: int, - arm_id: int, - asession: AsyncSession, -) -> Sequence[BayesianABDrawDB]: - """ - Get the observations of a specific arm in an A/B experiment. 
- """ - stmt = ( - select(BayesianABDrawDB) - .where( - and_( - BayesianABDrawDB.experiment_id == experiment_id, - BayesianABDrawDB.arm_id == arm_id, - BayesianABDrawDB.reward.is_not(None), - ) - ) - .order_by(BayesianABDrawDB.observed_datetime_utc) - ) - - result = await asession.execute(stmt) - return result.unique().scalars().all() - - -async def get_bayes_ab_obs_by_experiment_id( - experiment_id: int, - workspace_id: int, - asession: AsyncSession, -) -> Sequence[BayesianABDrawDB]: - """ - Get the observations of the A/B experiment. - Verified to belong to the specified workspace. - """ - # First, verify experiment belongs to the workspace - experiment = await get_bayes_ab_experiment_by_id( - experiment_id=experiment_id, - workspace_id=workspace_id, - asession=asession, - ) - - if experiment is None: - # Return empty list if experiment doesn't exist or doesn't belong to workspace - return [] - - # Get observations for this experiment - stmt = ( - select(BayesianABDrawDB) - .where( - and_( - BayesianABDrawDB.experiment_id == experiment_id, - BayesianABDrawDB.reward.is_not(None), - ) - ) - .order_by(BayesianABDrawDB.observed_datetime_utc) - ) - - result = await asession.execute(stmt) - return result.unique().scalars().all() - - -async def get_bayes_ab_draw_by_id( - draw_id: str, asession: AsyncSession -) -> BayesianABDrawDB | None: - """ - Get a draw by its ID - """ - statement = select(BayesianABDrawDB).where(BayesianABDrawDB.draw_id == draw_id) - result = await asession.execute(statement) - - return result.unique().scalar_one_or_none() - - -async def get_bayes_ab_draw_by_client_id( - client_id: str, experiment_id: int, asession: AsyncSession -) -> BayesianABDrawDB | None: - """ - Get a draw by its client ID for a specific experiment. 
- """ - statement = select(BayesianABDrawDB).where( - and_( - BayesianABDrawDB.client_id == client_id, - BayesianABDrawDB.client_id.is_not(None), - BayesianABDrawDB.experiment_id == experiment_id, - ) - ) - result = await asession.execute(statement) - - return result.unique().scalars().first() diff --git a/backend/app/bayes_ab/observation.py b/backend/app/bayes_ab/observation.py deleted file mode 100644 index 212dc91..0000000 --- a/backend/app/bayes_ab/observation.py +++ /dev/null @@ -1,75 +0,0 @@ -from datetime import datetime, timezone - -from fastapi.exceptions import HTTPException -from sqlalchemy.ext.asyncio import AsyncSession - -from ..schemas import ObservationType, Outcome, RewardLikelihood -from .models import ( - BayesianABArmDB, - BayesianABDB, - BayesianABDrawDB, - save_bayes_ab_observation_to_db, -) -from .schemas import ( - BayesABArmResponse, - BayesianABSample, -) - - -async def update_based_on_outcome( - experiment: BayesianABDB, - draw: BayesianABDrawDB, - outcome: float, - asession: AsyncSession, - observation: ObservationType, -) -> BayesABArmResponse: - """ - Update the arm parameters based on the outcome. - - This is a helper function to allow `auto_fail` job to call - it as well. - """ - update_experiment_metadata(experiment) - - arm = get_arm_from_experiment(experiment, draw.arm_id) - arm.n_outcomes += 1 - - experiment_data = BayesianABSample.model_validate(experiment) - if experiment_data.reward_type == RewardLikelihood.BERNOULLI: - Outcome(outcome) # Check if reward is 0 or 1 - - await save_updated_data(arm, draw, outcome, asession) - - return BayesABArmResponse.model_validate(arm) - - -def update_experiment_metadata(experiment: BayesianABDB) -> None: - """ - Update the experiment metadata with new information. 
- """ - experiment.n_trials += 1 - experiment.last_trial_datetime_utc = datetime.now(tz=timezone.utc) - - -def get_arm_from_experiment(experiment: BayesianABDB, arm_id: int) -> BayesianABArmDB: - """ - Get and validate the arm from the experiment. - """ - arms = [a for a in experiment.arms if a.arm_id == arm_id] - if not arms: - raise HTTPException(status_code=404, detail=f"Arm with id {arm_id} not found") - return arms[0] - - -async def save_updated_data( - arm: BayesianABArmDB, - draw: BayesianABDrawDB, - outcome: float, - asession: AsyncSession, -) -> None: - """ - Save the updated data to the database. - """ - asession.add(arm) - await asession.commit() - await save_bayes_ab_observation_to_db(draw, outcome, asession) diff --git a/backend/app/bayes_ab/routers.py b/backend/app/bayes_ab/routers.py deleted file mode 100644 index c1041a2..0000000 --- a/backend/app/bayes_ab/routers.py +++ /dev/null @@ -1,524 +0,0 @@ -from typing import Annotated, Optional -from uuid import uuid4 - -from fastapi import APIRouter, Depends -from fastapi.exceptions import HTTPException -from sqlalchemy.ext.asyncio import AsyncSession - -from ..auth.dependencies import ( - authenticate_workspace_key, - get_verified_user, - require_admin_role, -) -from ..database import get_async_session -from ..models import get_notifications_from_db, save_notifications_to_db -from ..schemas import NotificationsResponse, ObservationType -from ..users.models import UserDB -from ..workspaces.models import ( - WorkspaceDB, - get_user_default_workspace, -) -from .models import ( - BayesianABDB, - BayesianABDrawDB, - delete_bayes_ab_experiment_by_id, - get_all_bayes_ab_experiments, - get_bayes_ab_draw_by_client_id, - get_bayes_ab_draw_by_id, - get_bayes_ab_experiment_by_id, - get_bayes_ab_obs_by_experiment_id, - save_bayes_ab_draw_to_db, - save_bayes_ab_to_db, -) -from .observation import update_based_on_outcome -from .sampling_utils import choose_arm, update_arm_params -from .schemas import ( - 
BayesABArmResponse, - BayesianAB, - BayesianABDrawResponse, - BayesianABObservationResponse, - BayesianABResponse, - BayesianABSample, -) - -router = APIRouter(prefix="/bayes_ab", tags=["Bayesian A/B Testing"]) - - -@router.post("/", response_model=BayesianABResponse) -async def create_ab_experiment( - experiment: BayesianAB, - user_db: Annotated[UserDB, Depends(require_admin_role)], - asession: AsyncSession = Depends(get_async_session), -) -> BayesianABResponse: - """ - Create a new experiment in the user's current workspace. - """ - workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) - - if workspace_db is None: - raise HTTPException( - status_code=404, - detail="Workspace not found for the user.", - ) - - bayes_ab = await save_bayes_ab_to_db( - experiment, user_db.user_id, workspace_db.workspace_id, asession - ) - - notifications = await save_notifications_to_db( - experiment_id=bayes_ab.experiment_id, - user_id=user_db.user_id, - notifications=experiment.notifications, - asession=asession, - ) - - bayes_ab_dict = bayes_ab.to_dict() - bayes_ab_dict["notifications"] = [n.to_dict() for n in notifications] - - return BayesianABResponse.model_validate(bayes_ab_dict) - - -@router.get("/", response_model=list[BayesianABResponse]) -async def get_bayes_abs( - user_db: Annotated[UserDB, Depends(get_verified_user)], - asession: AsyncSession = Depends(get_async_session), -) -> list[BayesianABResponse]: - """ - Get details of all experiments in the user's current workspace. 
- """ - workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) - - if workspace_db is None: - raise HTTPException( - status_code=404, - detail="Workspace not found for the user.", - ) - - experiments = await get_all_bayes_ab_experiments( - workspace_db.workspace_id, asession - ) - - all_experiments = [] - for exp in experiments: - exp_dict = exp.to_dict() - exp_dict["notifications"] = [ - n.to_dict() - for n in await get_notifications_from_db( - exp.experiment_id, exp.user_id, asession - ) - ] - all_experiments.append( - BayesianABResponse.model_validate( - { - **exp_dict, - "notifications": [ - NotificationsResponse(**n) for n in exp_dict["notifications"] - ], - } - ) - ) - return all_experiments - - -@router.get("/{experiment_id}", response_model=BayesianABResponse) -async def get_bayes_ab( - experiment_id: int, - user_db: Annotated[UserDB, Depends(get_verified_user)], - asession: AsyncSession = Depends(get_async_session), -) -> BayesianABResponse: - """ - Get details of experiment with the provided `experiment_id`. 
- """ - workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) - - if workspace_db is None: - raise HTTPException( - status_code=404, - detail="Workspace not found for the user.", - ) - - experiment = await get_bayes_ab_experiment_by_id( - experiment_id, workspace_db.workspace_id, asession - ) - - if experiment is None: - raise HTTPException( - status_code=404, detail=f"Experiment with id {experiment_id} not found" - ) - - experiment_dict = experiment.to_dict() - experiment_dict["notifications"] = [ - n.to_dict() - for n in await get_notifications_from_db( - experiment.experiment_id, experiment.user_id, asession - ) - ] - - return BayesianABResponse.model_validate(experiment_dict) - - -@router.delete("/{experiment_id}", response_model=dict) -async def delete_bayes_ab( - experiment_id: int, - user_db: Annotated[UserDB, Depends(require_admin_role)], - asession: AsyncSession = Depends(get_async_session), -) -> dict: - """ - Delete the experiment with the provided `experiment_id`. 
- """ - try: - workspace_db = await get_user_default_workspace( - asession=asession, user_db=user_db - ) - - if workspace_db is None: - raise HTTPException( - status_code=404, - detail="Workspace not found for the user.", - ) - - experiment = await get_bayes_ab_experiment_by_id( - experiment_id, workspace_db.workspace_id, asession - ) - if experiment is None: - raise HTTPException( - status_code=404, detail=f"Experiment with id {experiment_id} not found" - ) - - await delete_bayes_ab_experiment_by_id( - experiment_id, workspace_db.workspace_id, asession - ) - - return {"message": f"Experiment with id {experiment_id} deleted successfully."} - except Exception as e: - raise HTTPException(status_code=500, detail=f"Error: {e}") from e - - -@router.get("/{experiment_id}/draw", response_model=BayesianABDrawResponse) -async def draw_arm( - experiment_id: int, - draw_id: Optional[str] = None, - client_id: Optional[str] = None, - workspace_db: WorkspaceDB = Depends(authenticate_workspace_key), - asession: AsyncSession = Depends(get_async_session), -) -> BayesianABDrawResponse: - """ - Get which arm to pull next for provided experiment. 
- """ - # Get workspace from user context - workspace_id = workspace_db.workspace_id - - experiment = await get_bayes_ab_experiment_by_id( - experiment_id, workspace_id, asession - ) - - if experiment is None: - raise HTTPException( - status_code=404, detail=f"Experiment with id {experiment_id} not found" - ) - - if experiment.sticky_assignment and not client_id: - raise HTTPException( - status_code=400, - detail="Client ID is required for sticky assignment.", - ) - - experiment_data = BayesianABSample.model_validate(experiment) - chosen_arm = choose_arm(experiment=experiment_data) - chosen_arm_id = experiment.arms[chosen_arm].arm_id - if experiment.sticky_assignment and client_id: - # Check if the client_id is already assigned to an arm - previous_draw = await get_bayes_ab_draw_by_client_id( - client_id=client_id, - experiment_id=experiment_id, - asession=asession, - ) - if previous_draw: - chosen_arm_id = previous_draw.arm_id - - # Check for existing draws - if draw_id is None: - draw_id = str(uuid4()) - - existing_draw = await get_bayes_ab_draw_by_id(draw_id=draw_id, asession=asession) - if existing_draw: - raise HTTPException( - status_code=400, - detail=f"Draw with id {draw_id} already exists for \ - experiment {experiment_id}", - ) - - try: - await save_bayes_ab_draw_to_db( - experiment_id=experiment.experiment_id, - arm_id=chosen_arm_id, - draw_id=draw_id, - client_id=client_id, - user_id=None, - asession=asession, - workspace_id=workspace_id, - ) - except Exception as e: - raise HTTPException( - status_code=500, - detail=f"Error saving draw to database: {e}", - ) from e - - return BayesianABDrawResponse.model_validate( - { - "draw_id": draw_id, - "client_id": client_id, - "arm": BayesABArmResponse.model_validate( - [arm for arm in experiment.arms if arm.arm_id == chosen_arm_id][0], - ), - } - ) - - -@router.put("/{experiment_id}/{draw_id}/{outcome}", response_model=BayesABArmResponse) -async def save_observation_for_arm( - experiment_id: int, - draw_id: 
str, - outcome: float, - workspace_db: WorkspaceDB = Depends(authenticate_workspace_key), - asession: AsyncSession = Depends(get_async_session), -) -> BayesABArmResponse: - """ - Update the arm with the provided `arm_id` for the given - `experiment_id` based on the `outcome`. - """ - # Get workspace from user context - workspace_id = workspace_db.workspace_id - - # Get and validate experiment - experiment, draw = await validate_experiment_and_draw( - experiment_id=experiment_id, - draw_id=draw_id, - workspace_id=workspace_id, - asession=asession, - ) - - return await update_based_on_outcome( - experiment=experiment, - draw=draw, - outcome=outcome, - asession=asession, - observation=ObservationType.USER, - ) - - -@router.get( - "/{experiment_id}/outcomes", - response_model=list[BayesianABObservationResponse], -) -async def get_outcomes( - experiment_id: int, - workspace_db: WorkspaceDB = Depends(authenticate_workspace_key), - asession: AsyncSession = Depends(get_async_session), -) -> list[BayesianABObservationResponse]: - """ - Get the outcomes for the experiment. - """ - # Get workspace from user context - workspace_id = workspace_db.workspace_id - - experiment = await get_bayes_ab_experiment_by_id( - experiment_id, workspace_id, asession - ) - if not experiment: - raise HTTPException( - status_code=404, detail=f"Experiment with id {experiment_id} not found" - ) - - rewards = await get_bayes_ab_obs_by_experiment_id( - experiment_id=experiment.experiment_id, - workspace_id=workspace_id, - asession=asession, - ) - - return [BayesianABObservationResponse.model_validate(reward) for reward in rewards] - - -@router.get( - "/{experiment_id}/arms", - response_model=list[BayesABArmResponse], -) -async def update_arms( - experiment_id: int, - workspace_db: WorkspaceDB = Depends(authenticate_workspace_key), - asession: AsyncSession = Depends(get_async_session), -) -> list[BayesABArmResponse]: - """ - Get the outcomes for the experiment. 
- """ - # Get workspace from user context - workspace_id = workspace_db.workspace_id - - # Check experiment params - experiment = await get_bayes_ab_experiment_by_id( - experiment_id, workspace_id, asession - ) - if not experiment: - raise HTTPException( - status_code=404, detail=f"Experiment with id {experiment_id} not found" - ) - - # Prepare data for arms update - ( - rewards, - treatments, - treatment_mu, - treatment_sigma, - control_mu, - control_sigma, - ) = await prepare_data_for_arms_update( - experiment=experiment, - workspace_id=workspace_id, - asession=asession, - ) - - # Make updates - arms_data = await make_updates_to_arms( - experiment=experiment, - treatment_mu=treatment_mu, - treatment_sigma=treatment_sigma, - control_mu=control_mu, - control_sigma=control_sigma, - rewards=rewards, - treatments=treatments, - asession=asession, - ) - - return arms_data - - -# ---- Helper functions ---- - - -async def validate_experiment_and_draw( - experiment_id: int, - draw_id: str, - workspace_id: int, - asession: AsyncSession, -) -> tuple[BayesianABDB, BayesianABDrawDB]: - """Validate the experiment and draw""" - experiment = await get_bayes_ab_experiment_by_id( - experiment_id, workspace_id, asession - ) - if experiment is None: - raise HTTPException( - status_code=404, detail=f"Experiment with id {experiment_id} not found" - ) - - draw = await get_bayes_ab_draw_by_id(draw_id=draw_id, asession=asession) - if draw is None: - raise HTTPException(status_code=404, detail=f"Draw with id {draw_id} not found") - - if draw.experiment_id != experiment_id: - raise HTTPException( - status_code=400, - detail=( - f"Draw with id {draw_id} does not belong " - f"to experiment with id {experiment_id}", - ), - ) - - if draw.reward is not None: - raise HTTPException( - status_code=400, - detail=f"Draw with id {draw_id} already has an outcome.", - ) - - return experiment, draw - - -async def prepare_data_for_arms_update( - experiment: BayesianABDB, - workspace_id: int, - asession: 
AsyncSession, -) -> tuple[list[float], list[float], float, float, float, float]: - """ - Prepare the data for arm update. - """ - # Get observations - observations = await get_bayes_ab_obs_by_experiment_id( - experiment_id=experiment.experiment_id, - workspace_id=workspace_id, - asession=asession, - ) - - if not observations: - raise HTTPException( - status_code=404, - detail=f"No observations found for experiment {experiment.experiment_id}", - ) - - rewards = [obs.reward for obs in observations] - - # Get treatment and control arms - arms_dict = { - arm.arm_id: 1.0 if arm.is_treatment_arm else 0.0 for arm in experiment.arms - } - - # Get params - treatment_mu, treatment_sigma = [ - (arm.mu_init, arm.sigma_init) for arm in experiment.arms if arm.is_treatment_arm - ][0] - control_mu, control_sigma = [ - (arm.mu_init, arm.sigma_init) - for arm in experiment.arms - if not arm.is_treatment_arm - ][0] - - treatments = [arms_dict[obs.arm_id] for obs in observations] - - return ( - rewards, - treatments, - treatment_mu, - treatment_sigma, - control_mu, - control_sigma, - ) - - -async def make_updates_to_arms( - experiment: BayesianABDB, - treatment_mu: float, - treatment_sigma: float, - control_mu: float, - control_sigma: float, - rewards: list[float], - treatments: list[float], - asession: AsyncSession, -) -> list[BayesABArmResponse]: - """ - Make updates to the arms of the experiment. 
- """ - # Make updates - experiment_data = BayesianABSample.model_validate(experiment) - new_means, new_sigmas = update_arm_params( - experiment=experiment_data, - mus=[treatment_mu, control_mu], - sigmas=[treatment_sigma, control_sigma], - rewards=rewards, - treatments=treatments, - ) - - arms_data = [] - for arm in experiment.arms: - if arm.is_treatment_arm: - arm.mu = new_means[0] - arm.sigma = new_sigmas[0] - else: - arm.mu = new_means[1] - arm.sigma = new_sigmas[1] - - asession.add(arm) - arms_data.append(BayesABArmResponse.model_validate(arm)) - - asession.add(experiment) - - await asession.commit() - - return arms_data diff --git a/backend/app/bayes_ab/sampling_utils.py b/backend/app/bayes_ab/sampling_utils.py deleted file mode 100644 index 0416f64..0000000 --- a/backend/app/bayes_ab/sampling_utils.py +++ /dev/null @@ -1,126 +0,0 @@ -import numpy as np -from scipy.optimize import minimize - -from ..schemas import ArmPriors, ContextLinkFunctions, RewardLikelihood -from .schemas import BayesianABSample - - -def _update_arms( - mus: np.ndarray, - sigmas: np.ndarray, - rewards: np.ndarray, - treatments: np.ndarray, - link_function: ContextLinkFunctions, - reward_likelihood: RewardLikelihood, - prior_type: ArmPriors, -) -> tuple[list, list]: - """ - Get arm posteriors. - - Parameters - ---------- - mu : np.ndarray - The mean of the Normal distribution. - sigma : np.ndarray - The standard deviation of the Normal distribution. - rewards : np.ndarray - The rewards. - treatments : np.ndarray - The treatments (binary-valued). - link_function : ContextLinkFunctions - The link function for parameters to rewards. - reward_likelihood : RewardLikelihood - The likelihood function of the reward. - prior_type : ArmPriors - The prior type of the arm. - """ - - # TODO we explicitly assume that there is only 1 treatment arm - def objective(treatment_effect_arms_bias: np.ndarray) -> float: - """ - Objective function for arm to outcome. 
- - Parameters - ---------- - treatment_effect : float - The treatment effect. - """ - treatment, control, bias = treatment_effect_arms_bias - - # log prior - log_prior = prior_type( - np.array([treatment, control]), mu=mus, covariance=np.diag(sigmas) - ) - - # log likelihood - log_likelihood = reward_likelihood( - rewards, - link_function(treatment * treatments + control * (1 - treatments) + bias), - ) - return -(log_prior + log_likelihood) - - result = minimize(objective, x0=np.zeros(3), method="L-BFGS-B", hess="2-point") - new_treatment_mean, new_control_mean, _ = result.x - new_treatment_sigma, new_control_sigma, _ = np.sqrt( - np.diag(result.hess_inv.todense()) # type: ignore - ) - return [new_treatment_mean, new_control_mean], [ - new_treatment_sigma, - new_control_sigma, - ] - - -def choose_arm(experiment: BayesianABSample) -> int: - """ - Choose arm based on posterior - - Parameters - ---------- - experiment : BayesianABSample - The experiment data containing priors and rewards for each arm. - """ - index = np.random.choice(len(experiment.arms), size=1) - return int(index[0]) - - -def update_arm_params( - experiment: BayesianABSample, - mus: list[float], - sigmas: list[float], - rewards: list[float], - treatments: list[float], -) -> tuple[list, list]: - """ - Update the arm parameters based on the reward type. - - Parameters - ---------- - experiment : BayesianABSample - The experiment data containing arms, prior type and reward - type information. - mus : list[float] - The means of the arms. - sigmas : list[float] - The standard deviations of the arms. - rewards : list[float] - The rewards. - treatments : list[float] - Which arm was applied corresponding to the reward. 
- """ - link_function = None - if experiment.reward_type == RewardLikelihood.NORMAL: - link_function = ContextLinkFunctions.NONE - elif experiment.reward_type == RewardLikelihood.BERNOULLI: - link_function = ContextLinkFunctions.LOGISTIC - else: - raise ValueError("Invalid reward type") - - return _update_arms( - mus=np.array(mus), - sigmas=np.array(sigmas), - rewards=np.array(rewards), - treatments=np.array(treatments), - link_function=link_function, - reward_likelihood=experiment.reward_type, - prior_type=experiment.prior_type, - ) diff --git a/backend/app/bayes_ab/schemas.py b/backend/app/bayes_ab/schemas.py deleted file mode 100644 index ef55d07..0000000 --- a/backend/app/bayes_ab/schemas.py +++ /dev/null @@ -1,145 +0,0 @@ -from datetime import datetime -from typing import Optional, Self - -from pydantic import BaseModel, ConfigDict, Field, model_validator - -from ..mab.schemas import ( - MABObservationResponse, - MultiArmedBanditBase, -) -from ..schemas import Notifications, NotificationsResponse, allowed_combos_bayes_ab - - -class BayesABArm(BaseModel): - """ - Pydantic model for a arm of the experiment. - """ - - name: str = Field( - max_length=150, - examples=["Arm 1"], - ) - description: str = Field( - max_length=500, - examples=["This is a description of the arm."], - ) - - mu_init: float = Field( - default=0.0, description="Mean parameter for treatment effect prior" - ) - sigma_init: float = Field( - default=1.0, description="Std dev parameter for treatment effect prior" - ) - n_outcomes: Optional[int] = Field( - default=0, - description="Number of outcomes for the arm", - examples=[0, 10, 15], - ) - is_treatment_arm: bool = Field( - default=True, - description="Is the arm a treatment arm", - examples=[True, False], - ) - - @model_validator(mode="after") - def check_values(self) -> Self: - """ - Check if the values are unique and set new attributes. 
- """ - if self.sigma_init is not None and self.sigma_init <= 0: - raise ValueError("Std dev must be greater than 0.") - return self - - model_config = ConfigDict(from_attributes=True) - - -class BayesABArmResponse(BayesABArm): - """ - Pydantic model for a response for contextual arm creation - """ - - arm_id: int - mu: float - sigma: float - model_config = ConfigDict(from_attributes=True) - - -class BayesianAB(MultiArmedBanditBase): - """ - Pydantic model for an A/B experiment. - """ - - arms: list[BayesABArm] - notifications: Notifications - model_config = ConfigDict(from_attributes=True) - - @model_validator(mode="after") - def arms_exactly_two(self) -> Self: - """ - Validate that the experiment has exactly two arms. - """ - if len(self.arms) != 2: - raise ValueError("The experiment must have at exactly two arms.") - return self - - @model_validator(mode="after") - def check_prior_reward_type_combo(self) -> Self: - """ - Validate that the prior and reward type combination is allowed. - """ - - if (self.prior_type, self.reward_type) not in allowed_combos_bayes_ab: - raise ValueError("Prior and reward type combo not supported.") - return self - - @model_validator(mode="after") - def check_treatment_and_control_arm(self) -> Self: - """ - Validate that the experiment has at least one control arm. - """ - if sum(arm.is_treatment_arm for arm in self.arms) != 1: - raise ValueError("The experiment must have one treatment and control arm.") - return self - - -class BayesianABResponse(MultiArmedBanditBase): - """ - Pydantic model for a response for an A/B experiment. - """ - - experiment_id: int - workspace_id: int - arms: list[BayesABArmResponse] - notifications: list[NotificationsResponse] - created_datetime_utc: datetime - last_trial_datetime_utc: Optional[datetime] = None - n_trials: int - - model_config = ConfigDict(from_attributes=True) - - -class BayesianABSample(MultiArmedBanditBase): - """ - Pydantic model for a sample A/B experiment. 
- """ - - experiment_id: int - arms: list[BayesABArmResponse] - - -class BayesianABObservationResponse(MABObservationResponse): - """ - Pydantic model for an observation response in an A/B experiment. - """ - - pass - - -class BayesianABDrawResponse(BaseModel): - """ - Pydantic model for a draw response in an A/B experiment. - """ - - draw_id: str - client_id: str | None - arm: BayesABArmResponse diff --git a/backend/app/contextual_mab/__init__.py b/backend/app/contextual_mab/__init__.py deleted file mode 100644 index fa07d07..0000000 --- a/backend/app/contextual_mab/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .routers import router # noqa: F401 diff --git a/backend/app/contextual_mab/models.py b/backend/app/contextual_mab/models.py deleted file mode 100644 index 60cf723..0000000 --- a/backend/app/contextual_mab/models.py +++ /dev/null @@ -1,483 +0,0 @@ -from datetime import datetime, timezone -from typing import Sequence - -import numpy as np -from sqlalchemy import ( - Float, - ForeignKey, - Integer, - String, - and_, - delete, - select, -) -from sqlalchemy.dialects.postgresql import ARRAY -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import Mapped, mapped_column, relationship - -from ..models import ( - ArmBaseDB, - Base, - DrawsBaseDB, - ExperimentBaseDB, - NotificationsDB, -) -from ..schemas import ObservationType -from .schemas import ContextualBandit - - -class ContextualBanditDB(ExperimentBaseDB): - """ - ORM for managing contextual experiments. 
- """ - - __tablename__ = "contextual_mabs" - - experiment_id: Mapped[int] = mapped_column( - ForeignKey("experiments_base.experiment_id", ondelete="CASCADE"), - primary_key=True, - nullable=False, - ) - - arms: Mapped[list["ContextualArmDB"]] = relationship( - "ContextualArmDB", back_populates="experiment", lazy="joined" - ) - - contexts: Mapped[list["ContextDB"]] = relationship( - "ContextDB", back_populates="experiment", lazy="joined" - ) - - draws: Mapped[list["ContextualDrawDB"]] = relationship( - "ContextualDrawDB", back_populates="experiment", lazy="joined" - ) - - __mapper_args__ = {"polymorphic_identity": "contextual_mabs"} - - def to_dict(self) -> dict: - """ - Convert the ORM object to a dictionary. - """ - return { - "experiment_id": self.experiment_id, - "user_id": self.user_id, - "workspace_id": self.workspace_id, - "name": self.name, - "description": self.description, - "sticky_assignment": self.sticky_assignment, - "auto_fail": self.auto_fail, - "auto_fail_value": self.auto_fail_value, - "auto_fail_unit": self.auto_fail_unit, - "created_datetime_utc": self.created_datetime_utc, - "is_active": self.is_active, - "n_trials": self.n_trials, - "arms": [arm.to_dict() for arm in self.arms], - "contexts": [context.to_dict() for context in self.contexts], - "prior_type": self.prior_type, - "reward_type": self.reward_type, - } - - -class ContextualArmDB(ArmBaseDB): - """ - ORM for managing contextual arms of an experiment - """ - - __tablename__ = "contextual_arms" - - arm_id: Mapped[int] = mapped_column( - ForeignKey("arms_base.arm_id", ondelete="CASCADE"), - primary_key=True, - nullable=False, - ) - - # prior variables for CMAB arms - mu_init: Mapped[float] = mapped_column(Float, nullable=False) - sigma_init: Mapped[float] = mapped_column(Float, nullable=False) - mu: Mapped[list[float]] = mapped_column(ARRAY(Float), nullable=False) - covariance: Mapped[list[float]] = mapped_column(ARRAY(Float), nullable=False) - - experiment: Mapped[ContextualBanditDB] = 
relationship( - "ContextualBanditDB", back_populates="arms", lazy="joined" - ) - draws: Mapped[list["ContextualDrawDB"]] = relationship( - "ContextualDrawDB", back_populates="arm", lazy="joined" - ) - - __mapper_args__ = {"polymorphic_identity": "contextual_arms"} - - def to_dict(self) -> dict: - """ - Convert the ORM object to a dictionary. - """ - return { - "arm_id": self.arm_id, - "name": self.name, - "description": self.description, - "mu_init": self.mu_init, - "sigma_init": self.sigma_init, - "mu": self.mu, - "covariance": self.covariance, - "draws": [draw.to_dict() for draw in self.draws], - } - - -class ContextDB(Base): - """ - ORM for managing context for an experiment - """ - - __tablename__ = "contexts" - - context_id: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False) - experiment_id: Mapped[int] = mapped_column( - Integer, ForeignKey("contextual_mabs.experiment_id"), nullable=False - ) - user_id: Mapped[int] = mapped_column( - Integer, ForeignKey("users.user_id"), nullable=False - ) - name: Mapped[str] = mapped_column(String(length=150), nullable=False) - description: Mapped[str] = mapped_column(String(length=500), nullable=True) - value_type: Mapped[str] = mapped_column(String(length=50), nullable=False) - - experiment: Mapped[ContextualBanditDB] = relationship( - "ContextualBanditDB", back_populates="contexts", lazy="joined" - ) - - def to_dict(self) -> dict: - """ - Convert the ORM object to a dictionary. 
- """ - return { - "context_id": self.context_id, - "name": self.name, - "description": self.description, - "value_type": self.value_type, - } - - -class ContextualDrawDB(DrawsBaseDB): - """ - ORM for managing draws of an experiment - """ - - __tablename__ = "contextual_draws" - - draw_id: Mapped[str] = mapped_column( - ForeignKey("draws_base.draw_id", ondelete="CASCADE"), - primary_key=True, - nullable=False, - ) - - context_val: Mapped[list] = mapped_column(ARRAY(Float), nullable=False) - arm: Mapped[ContextualArmDB] = relationship( - "ContextualArmDB", back_populates="draws", lazy="joined" - ) - experiment: Mapped[ContextualBanditDB] = relationship( - "ContextualBanditDB", back_populates="draws", lazy="joined" - ) - - __mapper_args__ = {"polymorphic_identity": "contextual_draws"} - - def to_dict(self) -> dict: - """ - Convert the ORM object to a dictionary. - """ - return { - "draw_id": self.draw_id, - "client_id": self.client_id, - "draw_datetime_utc": self.draw_datetime_utc, - "context_val": self.context_val, - "arm_id": self.arm_id, - "experiment_id": self.experiment_id, - "user_id": self.user_id, - "reward": self.reward, - "observation_type": self.observation_type, - "observed_datetime_utc": self.observed_datetime_utc, - } - - -async def save_contextual_mab_to_db( - experiment: ContextualBandit, - user_id: int, - workspace_id: int, - asession: AsyncSession, -) -> ContextualBanditDB: - """ - Save the experiment to the database. 
- """ - contexts = [ - ContextDB( - name=context.name, - description=context.description, - value_type=context.value_type.value, - user_id=user_id, - ) - for context in experiment.contexts - ] - arms = [] - for arm in experiment.arms: - arms.append( - ContextualArmDB( - name=arm.name, - description=arm.description, - mu_init=arm.mu_init, - sigma_init=arm.sigma_init, - mu=(np.ones(len(experiment.contexts)) * arm.mu_init).tolist(), - covariance=( - np.identity(len(experiment.contexts)) * arm.sigma_init - ).tolist(), - user_id=user_id, - n_outcomes=arm.n_outcomes, - ) - ) - - experiment_db = ContextualBanditDB( - name=experiment.name, - description=experiment.description, - user_id=user_id, - workspace_id=workspace_id, - is_active=experiment.is_active, - created_datetime_utc=datetime.now(timezone.utc), - n_trials=0, - arms=arms, - sticky_assignment=experiment.sticky_assignment, - auto_fail=experiment.auto_fail, - auto_fail_value=experiment.auto_fail_value, - auto_fail_unit=experiment.auto_fail_unit, - contexts=contexts, - prior_type=experiment.prior_type.value, - reward_type=experiment.reward_type.value, - ) - - asession.add(experiment_db) - await asession.commit() - await asession.refresh(experiment_db) - - return experiment_db - - -async def get_all_contextual_mabs( - workspace_id: int, - asession: AsyncSession, -) -> Sequence[ContextualBanditDB]: - """ - Get all the contextual experiments from the database for a specific workspace. - """ - statement = ( - select(ContextualBanditDB) - .where(ContextualBanditDB.workspace_id == workspace_id) - .order_by(ContextualBanditDB.experiment_id) - ) - - return (await asession.execute(statement)).unique().scalars().all() - - -async def get_contextual_mab_by_id( - experiment_id: int, workspace_id: int, asession: AsyncSession -) -> ContextualBanditDB | None: - """ - Get the contextual experiment by id from a specific workspace. 
- """ - condition = [ - ContextualBanditDB.experiment_id == experiment_id, - ContextualBanditDB.workspace_id == workspace_id, - ] - - statement = select(ContextualBanditDB).where(*condition) - result = await asession.execute(statement) - - return result.unique().scalar_one_or_none() - - -async def delete_contextual_mab_by_id( - experiment_id: int, workspace_id: int, asession: AsyncSession -) -> None: - """ - Delete the contextual experiment by id. - """ - await asession.execute( - delete(NotificationsDB).where(NotificationsDB.experiment_id == experiment_id) - ) - - await asession.execute( - delete(ContextualDrawDB).where(ContextualDrawDB.experiment_id == experiment_id) - ) - - await asession.execute( - delete(ContextDB).where(ContextDB.experiment_id == experiment_id) - ) - - await asession.execute( - delete(ContextualArmDB).where(ContextualArmDB.experiment_id == experiment_id) - ) - - await asession.execute( - delete(ContextualBanditDB).where( - and_( - ContextualBanditDB.workspace_id == workspace_id, - ContextualBanditDB.experiment_id == experiment_id, - ContextualBanditDB.experiment_id == ExperimentBaseDB.experiment_id, - ) - ) - ) - await asession.commit() - return None - - -async def save_contextual_obs_to_db( - draw: ContextualDrawDB, - reward: float, - asession: AsyncSession, - observation_type: ObservationType, -) -> ContextualDrawDB: - """ - Save the observation to the database. 
- """ - draw.reward = reward - draw.observed_datetime_utc = datetime.now(timezone.utc) - draw.observation_type = observation_type # Remove .value, pass enum directly - - await asession.commit() - await asession.refresh(draw) - - return draw - - -async def get_contextual_obs_by_experiment_arm_id( - experiment_id: int, - arm_id: int, - asession: AsyncSession, -) -> Sequence[ContextualDrawDB]: - """Get the observations for a specific arm of an experiment.""" - statement = ( - select(ContextualDrawDB) - .where( - and_( - ContextualDrawDB.experiment_id == experiment_id, - ContextualDrawDB.arm_id == arm_id, - ContextualDrawDB.reward.is_not(None), - ) - ) - .order_by(ContextualDrawDB.observed_datetime_utc) - ) - - result = await asession.execute(statement) - return result.unique().scalars().all() - - -async def get_all_contextual_obs_by_experiment_id( - experiment_id: int, - workspace_id: int, - asession: AsyncSession, -) -> Sequence[ContextualDrawDB]: - """ - Get all observations for an experiment, - verified to belong to the specified workspace. - """ - # First, verify experiment belongs to the workspace - experiment = await get_contextual_mab_by_id( - experiment_id=experiment_id, - workspace_id=workspace_id, - asession=asession, - ) - - if experiment is None: - # Return empty list if experiment doesn't exist or doesn't belong to workspace - return [] - - # Get all observations for this experiment - statement = ( - select(ContextualDrawDB) - .where( - and_( - ContextualDrawDB.experiment_id == experiment_id, - ContextualDrawDB.reward.is_not(None), - ) - ) - .order_by(ContextualDrawDB.observed_datetime_utc) - ) - - result = await asession.execute(statement) - return result.unique().scalars().all() - - -async def get_draw_by_id( - draw_id: str, asession: AsyncSession -) -> ContextualDrawDB | None: - """ - Get the draw by its ID, which should be unique across the system. 
- """ - statement = select(ContextualDrawDB).where(ContextualDrawDB.draw_id == draw_id) - result = await asession.execute(statement) - return result.unique().scalar_one_or_none() - - -async def get_draw_by_client_id( - client_id: str, experiment_id: int, asession: AsyncSession -) -> ContextualDrawDB | None: - """ - Get the draw by client id for a specific experiment. - """ - statement = ( - select(ContextualDrawDB) - .where(ContextualDrawDB.client_id == client_id) - .where(ContextualDrawDB.client_id.is_not(None)) - .where(ContextualDrawDB.experiment_id == experiment_id) - ) - result = await asession.execute(statement) - - return result.unique().scalars().first() - - -async def save_draw_to_db( - experiment_id: int, - arm_id: int, - context_val: list[float], - draw_id: str, - client_id: str | None, - user_id: int | None, - asession: AsyncSession, - workspace_id: int | None, -) -> ContextualDrawDB: - """ - Save the draw to the database. - """ - # If user_id is not provided but needed, get it from the experiment - if user_id is None: - if workspace_id is not None: - # Try to get experiment with workspace_id - experiment = await get_contextual_mab_by_id( - experiment_id=experiment_id, - workspace_id=workspace_id, - asession=asession, - ) - if experiment: - user_id = experiment.user_id - else: - raise ValueError(f"Experiment with id {experiment_id} not found") - else: - # Fall back to direct get if workspace_id not provided - experiment = await asession.get(ContextualBanditDB, experiment_id) - if experiment: - user_id = experiment.user_id - else: - raise ValueError(f"Experiment with id {experiment_id} not found") - - if user_id is None: - raise ValueError("User ID must be provided or derivable from experiment") - - draw_db = ContextualDrawDB( - draw_id=draw_id, - client_id=client_id, - arm_id=arm_id, - experiment_id=experiment_id, - user_id=user_id, - context_val=context_val, - draw_datetime_utc=datetime.now(timezone.utc), - ) - - asession.add(draw_db) - await 
asession.commit() - await asession.refresh(draw_db) - - return draw_db diff --git a/backend/app/contextual_mab/observation.py b/backend/app/contextual_mab/observation.py deleted file mode 100644 index e655bbf..0000000 --- a/backend/app/contextual_mab/observation.py +++ /dev/null @@ -1,126 +0,0 @@ -from datetime import datetime, timezone -from typing import Sequence - -import numpy as np -from fastapi.exceptions import HTTPException -from sqlalchemy.ext.asyncio import AsyncSession - -from ..schemas import ( - ObservationType, - RewardLikelihood, -) -from .models import ( - ContextualArmDB, - ContextualBanditDB, - ContextualDrawDB, - get_contextual_obs_by_experiment_arm_id, - save_contextual_obs_to_db, -) -from .sampling_utils import update_arm_params -from .schemas import ( - ContextualArmResponse, - ContextualBanditSample, -) - - -async def update_based_on_outcome( - experiment: ContextualBanditDB, - draw: ContextualDrawDB, - reward: float, - asession: AsyncSession, - observation_type: ObservationType, -) -> ContextualArmResponse: - """ - Update the arm based on the outcome of the draw. - - This is a helper function to allow `auto_fail` job to call - it as well. 
- """ - - update_experiment_metadata(experiment) - - arm = get_arm_from_experiment(experiment, draw.arm_id) - arm.n_outcomes += 1 - - # Ensure reward is binary for Bernoulli reward type - if experiment.reward_type == RewardLikelihood.BERNOULLI.value: - if reward not in [0, 1]: - raise HTTPException( - status_code=400, - detail="Reward must be 0 or 1 for Bernoulli reward type.", - ) - - # Get data for arm update - all_obs, contexts, rewards = await prepare_data_for_arm_update( - experiment.experiment_id, arm.arm_id, asession, draw, reward - ) - - experiment_data = ContextualBanditSample.model_validate(experiment) - mu, covariance = update_arm_params( - arm=ContextualArmResponse.model_validate(arm), - prior_type=experiment_data.prior_type, - reward_type=experiment_data.reward_type, - context=contexts, - reward=rewards, - ) - - await save_updated_data( - arm, mu, covariance, draw, reward, observation_type, asession - ) - - return ContextualArmResponse.model_validate(arm) - - -def update_experiment_metadata(experiment: ContextualBanditDB) -> None: - """Update experiment metadata with new trial information""" - experiment.n_trials += 1 - experiment.last_trial_datetime_utc = datetime.now(tz=timezone.utc) - - -def get_arm_from_experiment( - experiment: ContextualBanditDB, arm_id: int -) -> ContextualArmDB: - """Get and validate the arm from the experiment""" - arms = [a for a in experiment.arms if a.arm_id == arm_id] - if not arms: - raise HTTPException(status_code=404, detail=f"Arm with id {arm_id} not found") - return arms[0] - - -async def prepare_data_for_arm_update( - experiment_id: int, - arm_id: int, - asession: AsyncSession, - draw: ContextualDrawDB, - reward: float, -) -> tuple[Sequence[ContextualDrawDB], list[list], list[float]]: - """Prepare the data needed for updating arm parameters""" - all_obs = await get_contextual_obs_by_experiment_arm_id( - experiment_id=experiment_id, - arm_id=arm_id, - asession=asession, - ) - - rewards = [obs.reward for obs in 
all_obs] + [reward] - contexts = [obs.context_val for obs in all_obs] - contexts.append(draw.context_val) - - return all_obs, contexts, rewards - - -async def save_updated_data( - arm: ContextualArmDB, - mu: np.ndarray, - covariance: np.ndarray, - draw: ContextualDrawDB, - reward: float, - observation_type: ObservationType, - asession: AsyncSession, -) -> None: - """Save the updated arm and observation data""" - arm.mu = mu.tolist() - arm.covariance = covariance.tolist() - asession.add(arm) - await asession.commit() - - await save_contextual_obs_to_db(draw, reward, asession, observation_type) diff --git a/backend/app/contextual_mab/routers.py b/backend/app/contextual_mab/routers.py deleted file mode 100644 index 08eea28..0000000 --- a/backend/app/contextual_mab/routers.py +++ /dev/null @@ -1,395 +0,0 @@ -from typing import Annotated, List, Optional -from uuid import uuid4 - -from fastapi import APIRouter, Depends -from fastapi.exceptions import HTTPException -from sqlalchemy.ext.asyncio import AsyncSession - -from ..auth.dependencies import ( - authenticate_workspace_key, - get_verified_user, - require_admin_role, -) -from ..database import get_async_session -from ..models import get_notifications_from_db, save_notifications_to_db -from ..schemas import ( - ContextType, - NotificationsResponse, - ObservationType, - Outcome, -) -from ..users.models import UserDB -from ..utils import setup_logger -from ..workspaces.models import ( - WorkspaceDB, - get_user_default_workspace, -) -from .models import ( - ContextualBanditDB, - ContextualDrawDB, - delete_contextual_mab_by_id, - get_all_contextual_mabs, - get_all_contextual_obs_by_experiment_id, - get_contextual_mab_by_id, - get_draw_by_client_id, - get_draw_by_id, - save_contextual_mab_to_db, - save_draw_to_db, -) -from .observation import update_based_on_outcome -from .sampling_utils import choose_arm -from .schemas import ( - CMABDrawResponse, - CMABObservationResponse, - ContextInput, - ContextualArmResponse, - 
ContextualBandit, - ContextualBanditResponse, - ContextualBanditSample, -) - -router = APIRouter(prefix="/contextual_mab", tags=["Contextual Bandits"]) - -logger = setup_logger(__name__) - - -@router.post("/", response_model=ContextualBanditResponse) -async def create_contextual_mabs( - experiment: ContextualBandit, - user_db: Annotated[UserDB, Depends(require_admin_role)], - asession: AsyncSession = Depends(get_async_session), -) -> ContextualBanditResponse | HTTPException: - """ - Create a new contextual experiment with different priors for each context. - """ - workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) - - if workspace_db is None: - raise HTTPException( - status_code=404, - detail="Workspace not found. Please create a workspace first.", - ) - - cmab = await save_contextual_mab_to_db( - experiment, user_db.user_id, workspace_db.workspace_id, asession - ) - notifications = await save_notifications_to_db( - experiment_id=cmab.experiment_id, - user_id=user_db.user_id, - notifications=experiment.notifications, - asession=asession, - ) - cmab_dict = cmab.to_dict() - cmab_dict["notifications"] = [n.to_dict() for n in notifications] - return ContextualBanditResponse.model_validate(cmab_dict) - - -@router.get("/", response_model=list[ContextualBanditResponse]) -async def get_contextual_mabs( - user_db: Annotated[UserDB, Depends(get_verified_user)], - asession: AsyncSession = Depends(get_async_session), -) -> list[ContextualBanditResponse]: - """ - Get details of all experiments. - """ - workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) - - if workspace_db is None: - raise HTTPException( - status_code=404, - detail="Workspace not found. 
Please create a workspace first.", - ) - - experiments = await get_all_contextual_mabs(workspace_db.workspace_id, asession) - all_experiments = [] - for exp in experiments: - exp_dict = exp.to_dict() - exp_dict["notifications"] = [ - n.to_dict() - for n in await get_notifications_from_db( - exp.experiment_id, exp.user_id, asession - ) - ] - all_experiments.append( - ContextualBanditResponse.model_validate( - { - **exp_dict, - "notifications": [ - NotificationsResponse.model_validate(n) - for n in exp_dict["notifications"] - ], - } - ) - ) - - return all_experiments - - -@router.get("/{experiment_id}", response_model=ContextualBanditResponse) -async def get_contextual_mab( - experiment_id: int, - user_db: Annotated[UserDB, Depends(get_verified_user)], - asession: AsyncSession = Depends(get_async_session), -) -> ContextualBanditResponse | HTTPException: - """ - Get details of experiment with the provided `experiment_id`. - """ - workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) - - if workspace_db is None: - raise HTTPException( - status_code=404, - detail="Workspace not found. Please create a workspace first.", - ) - - experiment = await get_contextual_mab_by_id( - experiment_id, workspace_db.workspace_id, asession - ) - if experiment is None: - raise HTTPException( - status_code=404, detail=f"Experiment with id {experiment_id} not found" - ) - - experiment_dict = experiment.to_dict() - experiment_dict["notifications"] = [ - n.to_dict() - for n in await get_notifications_from_db( - experiment.experiment_id, experiment.user_id, asession - ) - ] - return ContextualBanditResponse.model_validate(experiment_dict) - - -@router.delete("/{experiment_id}", response_model=dict) -async def delete_contextual_mab( - experiment_id: int, - user_db: Annotated[UserDB, Depends(require_admin_role)], - asession: AsyncSession = Depends(get_async_session), -) -> dict: - """ - Delete the experiment with the provided `experiment_id`. 
- """ - try: - workspace_db = await get_user_default_workspace( - asession=asession, user_db=user_db - ) - - if workspace_db is None: - raise HTTPException( - status_code=404, - detail="Workspace not found. Please create a workspace first.", - ) - - experiment = await get_contextual_mab_by_id( - experiment_id, workspace_db.workspace_id, asession - ) - if experiment is None: - raise HTTPException( - status_code=404, detail=f"Experiment with id {experiment_id} not found" - ) - await delete_contextual_mab_by_id( - experiment_id, workspace_db.workspace_id, asession - ) - return {"detail": f"Experiment {experiment_id} deleted successfully."} - except Exception as e: - raise HTTPException(status_code=500, detail=f"Error: {e}") from e - - -@router.post("/{experiment_id}/draw", response_model=CMABDrawResponse) -async def draw_arm( - experiment_id: int, - context: List[ContextInput], - draw_id: Optional[str] = None, - client_id: Optional[str] = None, - workspace_db: WorkspaceDB = Depends(authenticate_workspace_key), - asession: AsyncSession = Depends(get_async_session), -) -> CMABDrawResponse: - """ - Get which arm to pull next for provided experiment. 
- """ - # Get workspace from user context - workspace_id = workspace_db.workspace_id - - experiment = await get_contextual_mab_by_id(experiment_id, workspace_id, asession) - - if experiment is None: - raise HTTPException( - status_code=404, detail=f"Experiment with id {experiment_id} not found" - ) - - # Check context inputs - if len(experiment.contexts) != len(context): - raise HTTPException( - status_code=400, - detail="Number of contexts provided does not match the num contexts.", - ) - experiment_data = ContextualBanditSample.model_validate(experiment) - sorted_context = list(sorted(context, key=lambda x: x.context_id)) - - try: - for c_input, c_exp in zip( - sorted_context, - sorted(experiment.contexts, key=lambda x: x.context_id), - ): - if c_exp.value_type == ContextType.BINARY.value: - Outcome(c_input.context_value) - except ValueError as e: - raise HTTPException( - status_code=400, - detail=f"Invalid context value: {e}", - ) from e - - # Generate UUID if not provided - if draw_id is None: - draw_id = str(uuid4()) - - existing_draw = await get_draw_by_id(draw_id, asession) - if existing_draw: - raise HTTPException( - status_code=400, - detail=f"Draw ID {draw_id} already exists.", - ) - - # Check if sticky assignment - if experiment.sticky_assignment and not client_id: - raise HTTPException( - status_code=400, - detail="Client ID is required for sticky assignment.", - ) - - chosen_arm = choose_arm( - experiment_data, - [c.context_value for c in sorted_context], - ) - chosen_arm_id = experiment.arms[chosen_arm].arm_id - if experiment.sticky_assignment and client_id: - previous_draw = await get_draw_by_client_id( - client_id=client_id, - experiment_id=experiment.experiment_id, - asession=asession, - ) - if previous_draw: - chosen_arm_id = previous_draw.arm_id - - try: - _ = await save_draw_to_db( - experiment_id=experiment.experiment_id, - arm_id=chosen_arm_id, - context_val=[c.context_value for c in sorted_context], - draw_id=draw_id, - client_id=client_id, - 
user_id=None, - asession=asession, - workspace_id=workspace_id, - ) - except Exception as e: - raise HTTPException( - status_code=500, - detail=f"Error saving draw to database: {e}", - ) from e - - return CMABDrawResponse.model_validate( - { - "draw_id": draw_id, - "client_id": client_id, - "arm": ContextualArmResponse.model_validate( - [arm for arm in experiment.arms if arm.arm_id == chosen_arm_id][0] - ), - } - ) - - -@router.put("/{experiment_id}/{draw_id}/{reward}", response_model=ContextualArmResponse) -async def update_arm( - experiment_id: int, - draw_id: str, - reward: float, - workspace_db: WorkspaceDB = Depends(authenticate_workspace_key), - asession: AsyncSession = Depends(get_async_session), -) -> ContextualArmResponse: - """ - Update the arm with the provided `arm_id` for the given - `experiment_id` based on the reward. - """ - # Get workspace from user context - workspace_id = workspace_db.workspace_id - - # Get the experiment and do checks - experiment, draw = await validate_experiment_and_draw( - experiment_id, draw_id, workspace_id, asession - ) - - return await update_based_on_outcome( - experiment, draw, reward, asession, ObservationType.USER - ) - - -@router.get( - "/{experiment_id}/outcomes", - response_model=list[CMABObservationResponse], -) -async def get_outcomes( - experiment_id: int, - workspace_db: WorkspaceDB = Depends(authenticate_workspace_key), - asession: AsyncSession = Depends(get_async_session), -) -> list[CMABObservationResponse]: - """ - Get the outcomes for the experiment. 
- """ - # Get workspace from user context - workspace_id = workspace_db.workspace_id - - experiment = await get_contextual_mab_by_id(experiment_id, workspace_id, asession) - if not experiment: - raise HTTPException( - status_code=404, detail=f"Experiment with id {experiment_id} not found" - ) - - observations = await get_all_contextual_obs_by_experiment_id( - experiment_id=experiment.experiment_id, - workspace_id=workspace_id, - asession=asession, - ) - return [CMABObservationResponse.model_validate(obs) for obs in observations] - - -async def validate_experiment_and_draw( - experiment_id: int, - draw_id: str, - workspace_id: int, - asession: AsyncSession, -) -> tuple[ContextualBanditDB, ContextualDrawDB]: - """ - Validate that the experiment exists in the workspace - and the draw exists for that experiment. - """ - experiment = await get_contextual_mab_by_id( - experiment_id=experiment_id, - workspace_id=workspace_id, - asession=asession, - ) - if experiment is None: - raise HTTPException( - status_code=404, detail=f"Experiment with id {experiment_id} not found" - ) - - draw = await get_draw_by_id(draw_id=draw_id, asession=asession) - if draw is None: - raise HTTPException(status_code=404, detail=f"Draw with id {draw_id} not found") - - if draw.experiment_id != experiment_id: - raise HTTPException( - status_code=400, - detail=( - f"Draw with id {draw_id} does not belong " - f"to experiment with id {experiment_id}", - ), - ) - - if draw.reward is not None: - raise HTTPException( - status_code=400, - detail=f"Draw with id {draw_id} already has a reward.", - ) - - return experiment, draw diff --git a/backend/app/contextual_mab/sampling_utils.py b/backend/app/contextual_mab/sampling_utils.py deleted file mode 100644 index 03c9784..0000000 --- a/backend/app/contextual_mab/sampling_utils.py +++ /dev/null @@ -1,172 +0,0 @@ -import numpy as np -from scipy.optimize import minimize - -from ..schemas import ArmPriors, ContextLinkFunctions, RewardLikelihood -from .schemas 
import ContextualArmResponse, ContextualBanditSample - - -def sample_normal( - mus: list[np.ndarray], - covariances: list[np.ndarray], - context: np.ndarray, - link_function: ContextLinkFunctions, -) -> int: - """ - Thompson Sampling with normal prior. - - Parameters - ---------- - mus: mean of Normal distribution for each arm - covariances: covariance matrix of Normal distribution for each arm - context: context vector - link_function: link function for the context - """ - samples = np.array( - [ - np.random.multivariate_normal(mean=mu, cov=cov) - for mu, cov in zip(mus, covariances) - ] - ).reshape(-1, len(context)) - probs = link_function(samples @ context) - return int(probs.argmax()) - - -def update_arm_normal( - current_mu: np.ndarray, - current_covariance: np.ndarray, - reward: float, - context: np.ndarray, - sigma_llhood: float, -) -> tuple[np.ndarray, np.ndarray]: - """ - Update the mean and covariance of the normal distribution. - - Parameters - ---------- - current_mu : The mean of the normal distribution. - current_covariance : The covariance matrix of the normal distribution. - reward : The reward of the arm. - context : The context vector. - sigma_llhood : The stddev of the likelihood. - """ - new_covariance_inv = ( - np.linalg.inv(current_covariance) + (context.T @ context) / sigma_llhood**2 - ) - new_covariance = np.linalg.inv(new_covariance_inv) - - new_mu = new_covariance @ ( - np.linalg.inv(current_covariance) @ current_mu - + context * reward / sigma_llhood**2 - ) - return new_mu, new_covariance - - -def update_arm_laplace( - current_mu: np.ndarray, - current_covariance: np.ndarray, - reward: np.ndarray, - context: np.ndarray, - link_function: ContextLinkFunctions, - reward_likelihood: RewardLikelihood, - prior_type: ArmPriors, -) -> tuple[np.ndarray, np.ndarray]: - """ - Update the mean and covariance using the Laplace approximation. - - Parameters - ---------- - current_mu : The mean of the normal distribution. 
- current_covariance : The covariance matrix of the normal distribution. - reward : The list of rewards for the arm. - context : The list of contexts for the arm. - link_function : The link function for parameters to rewards. - reward_likelihood : The likelihood function of the reward. - prior_type : The prior type of the arm. - """ - - def objective(theta: np.ndarray) -> float: - """ - Objective function for the Laplace approximation. - - Parameters - ---------- - theta : The parameters of the arm. - """ - # Log prior - log_prior = prior_type(theta, mu=current_mu, covariance=current_covariance) - - # Log likelihood - log_likelihood = reward_likelihood(reward, link_function(context @ theta)) - - return -log_prior - log_likelihood - - result = minimize(objective, current_mu, method="L-BFGS-B", hess="2-point") - new_mu = result.x - covariance = result.hess_inv.todense() # type: ignore - - new_covariance = 0.5 * (covariance + covariance.T) - return new_mu, new_covariance.astype(np.float64) - - -def choose_arm(experiment: ContextualBanditSample, context: list[float]) -> int: - """ - Choose the arm with the highest probability. - - Parameters - ---------- - experiment : The experiment object. - context : The context vector. - """ - link_function = ( - ContextLinkFunctions.NONE - if experiment.reward_type == RewardLikelihood.NORMAL - else ContextLinkFunctions.LOGISTIC - ) - return sample_normal( - mus=[np.array(arm.mu) for arm in experiment.arms], - covariances=[np.array(arm.covariance) for arm in experiment.arms], - context=np.array(context), - link_function=link_function, - ) - - -def update_arm_params( - arm: ContextualArmResponse, - prior_type: ArmPriors, - reward_type: RewardLikelihood, - reward: list, - context: list, -) -> tuple[np.ndarray, np.ndarray]: - """ - Update the arm parameters. - - Parameters - ---------- - arm : The arm object. - prior_type : The prior type of the arm. - reward_type : The reward type of the arm. - reward : All rewards for the arm. 
- context : All context vectors for the arm. - """ - if (prior_type == ArmPriors.NORMAL) and (reward_type == RewardLikelihood.NORMAL): - return update_arm_normal( - current_mu=np.array(arm.mu), - current_covariance=np.array(arm.covariance), - reward=reward[-1], - context=np.array(context[-1]), - sigma_llhood=1.0, # TODO: need to implement likelihood stddev - ) - elif (prior_type == ArmPriors.NORMAL) and ( - reward_type == RewardLikelihood.BERNOULLI - ): - return update_arm_laplace( - current_mu=np.array(arm.mu), - current_covariance=np.array(arm.covariance), - reward=np.array(reward), - context=np.array(context), - link_function=ContextLinkFunctions.LOGISTIC, - reward_likelihood=RewardLikelihood.BERNOULLI, - prior_type=ArmPriors.NORMAL, - ) - else: - raise ValueError("Prior and reward type combination is not supported.") diff --git a/backend/app/contextual_mab/schemas.py b/backend/app/contextual_mab/schemas.py deleted file mode 100644 index 57baaf4..0000000 --- a/backend/app/contextual_mab/schemas.py +++ /dev/null @@ -1,268 +0,0 @@ -from datetime import datetime -from typing import Optional, Self - -from pydantic import BaseModel, ConfigDict, Field, model_validator - -from ..schemas import ( - ArmPriors, - AutoFailUnitType, - ContextType, - Notifications, - NotificationsResponse, - RewardLikelihood, - allowed_combos_cmab, -) - - -class Context(BaseModel): - """ - Pydantic model for a binary-valued context of the experiment. 
- """ - - name: str = Field( - description="Name of the context", - examples=["Context 1"], - ) - description: str = Field( - description="Description of the context", - examples=["This is a description of the context."], - ) - value_type: ContextType = Field( - description="Type of value the context can take", default=ContextType.BINARY - ) - model_config = ConfigDict(from_attributes=True) - - -class ContextResponse(Context): - """ - Pydantic model for an response for context creation - """ - - context_id: int - model_config = ConfigDict(from_attributes=True) - - -class ContextInput(BaseModel): - """ - Pydantic model for a context input - """ - - context_id: int - context_value: float - model_config = ConfigDict(from_attributes=True) - - -class ContextualArm(BaseModel): - """ - Pydantic model for a contextual arm of the experiment. - """ - - name: str = Field( - max_length=150, - examples=["Arm 1"], - ) - description: str = Field( - max_length=500, - examples=["This is a description of the arm."], - ) - - mu_init: float = Field( - default=0.0, - examples=[0.0, 1.2, 5.7], - description="Mean parameter for Normal prior", - ) - - sigma_init: float = Field( - default=1.0, - examples=[1.0, 0.5, 2.0], - description="Standard deviation parameter for Normal prior", - ) - n_outcomes: Optional[int] = Field( - default=0, - description="Number of outcomes for the arm", - examples=[0, 10, 15], - ) - - @model_validator(mode="after") - def check_values(self) -> Self: - """ - Check if the values are unique and set new attributes. 
- """ - sigma = self.sigma_init - if sigma is not None and sigma <= 0: - raise ValueError("Std dev must be greater than 0.") - return self - - model_config = ConfigDict(from_attributes=True) - - -class ContextualArmResponse(ContextualArm): - """ - Pydantic model for an response for contextual arm creation - """ - - arm_id: int - mu: list[float] - covariance: list[list[float]] - - model_config = ConfigDict(from_attributes=True) - - -class ContextualBanditBase(BaseModel): - """ - Pydantic model for a contextual experiment - Base model. - Note: Do not use this model directly. Use ContextualBandit instead. - """ - - name: str = Field( - max_length=150, - examples=["Experiment 1"], - ) - - description: str = Field( - max_length=500, - examples=["This is a description of the experiment."], - ) - - sticky_assignment: bool = Field( - description="Whether the arm assignment is sticky or not.", - default=False, - ) - - auto_fail: bool = Field( - description=( - "Whether the experiment should fail automatically after " - "a certain period if no outcome is registered." - ), - default=False, - ) - - auto_fail_value: Optional[int] = Field( - description="The time period after which the experiment should fail.", - default=None, - ) - - auto_fail_unit: Optional[AutoFailUnitType] = Field( - description="The time unit for the auto fail period.", - default=None, - ) - - reward_type: RewardLikelihood = Field( - description="The type of reward we observe from the experiment.", - default=RewardLikelihood.BERNOULLI, - ) - - prior_type: ArmPriors = Field( - description="The type of prior distribution for the arms.", - default=ArmPriors.NORMAL, - ) - - is_active: bool = True - - model_config = ConfigDict(from_attributes=True) - - -class ContextualBandit(ContextualBanditBase): - """ - Pydantic model for a contextual experiment. 
- """ - - arms: list[ContextualArm] - contexts: list[Context] - notifications: Notifications - - @model_validator(mode="after") - def auto_fail_unit_and_value_set(self) -> Self: - """ - Validate that the auto fail unit and value are set if auto fail is True. - """ - if self.auto_fail: - if ( - not self.auto_fail_value - or not self.auto_fail_unit - or self.auto_fail_value <= 0 - ): - raise ValueError( - ( - "Auto fail is enabled. " - "Please provide both auto_fail_value and auto_fail_unit." - ) - ) - return self - - @model_validator(mode="after") - def arms_at_least_two(self) -> Self: - """ - Validate that the experiment has at least two arms. - """ - if len(self.arms) < 2: - raise ValueError("The experiment must have at least two arms.") - return self - - @model_validator(mode="after") - def check_prior_reward_type_combo(self) -> Self: - """ - Validate that the prior and reward type combination is allowed. - """ - - if (self.prior_type, self.reward_type) not in allowed_combos_cmab: - raise ValueError("Prior and reward type combo not supported.") - return self - - model_config = ConfigDict(from_attributes=True) - - -class ContextualBanditResponse(ContextualBanditBase): - """ - Pydantic model for an response for contextual experiment creation. - Returns the id of the experiment, the arms and the contexts - """ - - experiment_id: int - workspace_id: int - arms: list[ContextualArmResponse] - contexts: list[ContextResponse] - notifications: list[NotificationsResponse] - created_datetime_utc: datetime - last_trial_datetime_utc: Optional[datetime] = None - n_trials: int - - model_config = ConfigDict(from_attributes=True) - - -class ContextualBanditSample(ContextualBanditBase): - """ - Pydantic model for a contextual experiment sample. 
- """ - - experiment_id: int - arms: list[ContextualArmResponse] - contexts: list[ContextResponse] - - -class CMABObservationResponse(BaseModel): - """ - Pydantic model for an response for contextual observation creation - """ - - arm_id: int - reward: float - context_val: list[float] - - draw_id: str - client_id: str | None - observed_datetime_utc: datetime - - model_config = ConfigDict(from_attributes=True) - - -class CMABDrawResponse(BaseModel): - """ - Pydantic model for an response for contextual arm draw - """ - - draw_id: str - client_id: str | None - arm: ContextualArmResponse - - model_config = ConfigDict(from_attributes=True) diff --git a/backend/app/experiments/dependencies.py b/backend/app/experiments/dependencies.py new file mode 100644 index 0000000..249a989 --- /dev/null +++ b/backend/app/experiments/dependencies.py @@ -0,0 +1,269 @@ +from datetime import datetime, timezone +from typing import Union + +import numpy as np +from fastapi.exceptions import HTTPException +from sqlalchemy.ext.asyncio import AsyncSession + +from .models import ( + ArmDB, + DrawDB, + ExperimentDB, + get_draw_by_id, + get_draws_with_rewards_by_experiment_id, + get_experiment_by_id_from_db, + get_notifications_from_db, + save_observation_to_db, +) +from .sampling_utils import update_arm +from .schemas import ( + ArmPriors, + ArmResponse, + ExperimentSample, + ExperimentsEnum, + NotificationsResponse, + ObservationType, + Outcome, + RewardLikelihood, +) + + +async def experiments_db_to_schema( + experiments_db: list[ExperimentDB], + asession: AsyncSession, +) -> list[ExperimentSample]: + """ + Convert a list of ExperimentDB objects to a list of ExperimentResponse schemas. 
+ """ + all_experiments = [] + for exp in experiments_db: + exp_dict = exp.to_dict() + exp_dict["notifications"] = [ + n.to_dict() + for n in await get_notifications_from_db( + experiment_id=exp.experiment_id, + user_id=exp.user_id, + workspace_id=exp.workspace_id, + asession=asession, + ) + ] + all_experiments.append( + ExperimentSample.model_validate( + { + **exp_dict, + "notifications": [ + NotificationsResponse(**n) for n in exp_dict["notifications"] + ], + } + ) + ) + + return all_experiments + + +async def validate_experiment_and_draw( + experiment_id: int, draw_id: str, workspace_id: int, asession: AsyncSession +) -> tuple[ExperimentDB, DrawDB]: + """ + Validate the experiment and draw. + """ + experiment = await get_experiment_by_id_from_db( + workspace_id=workspace_id, experiment_id=experiment_id, asession=asession + ) + # Check experiment + if experiment is None: + raise HTTPException( + status_code=404, detail=f"Experiment with id {experiment_id} not found" + ) + + draw = await get_draw_by_id(draw_id=draw_id, asession=asession) + # Check draw + if draw is None: + raise HTTPException(status_code=404, detail=f"Draw with id {draw_id} not found") + if draw.experiment_id != experiment_id: + raise HTTPException( + status_code=404, + detail=( + f"Draw with id {draw_id} does not belong to " + f"experiment with id {experiment_id}" + ), + ) + if draw.reward: + raise HTTPException( + status_code=400, + detail=f"Draw with id {draw_id} has already been updated with a reward.", + ) + + return experiment, draw + + +async def format_rewards_for_arm_update( + experiment: ExperimentDB, + chosen_arm_id: int, + reward: float, + context_val: Union[list[float], None], + asession: AsyncSession, +) -> tuple[list[float], list[list[float]] | None, list[float] | None]: + """ + Format the rewards for the arm update. 
+ """ + previous_rewards = await get_draws_with_rewards_by_experiment_id( + experiment_id=experiment.experiment_id, asession=asession + ) + + rewards = [] + treatments = None + contexts = None + + if previous_rewards: + if experiment.exp_type != ExperimentsEnum.BAYESAB.value: + rewards = [ + draw.reward for draw in previous_rewards if draw.arm_id == chosen_arm_id + ] + else: + treatments = [] + for draw in previous_rewards: + rewards.append(draw.reward) + treatments.append( + [ + float(arm.is_treatment_arm) + for arm in experiment.arms + if arm.arm_id == draw.arm_id + ][0] + ) + + if experiment.exp_type == ExperimentsEnum.CMAB.value: + contexts = [] + for draw in previous_rewards: + if draw.context_val: + contexts.append(draw.context_val) + else: + raise ValueError( + f"Context value is missing for draw id {draw.draw_id}" + f" in CMAB experiment {draw.experiment_id}." + ) + + rewards_list = [reward] if rewards is None else [reward] + rewards + + context_list = None if not context_val else [context_val] + if contexts and context_list: + context_list = context_list + contexts + + chosen_arm_index = int( + np.argwhere([a.arm_id == chosen_arm_id for a in experiment.arms])[0][0] + ) + new_treatment = [float(experiment.arms[chosen_arm_index].is_treatment_arm)] + treatments_list = ( + new_treatment if treatments is None else new_treatment + treatments + ) + + return rewards_list, context_list, treatments_list + + +async def update_arm_based_on_outcome( + experiment: ExperimentDB, + draw: DrawDB, + rewards: list[float], + contexts: Union[list[list[float]], None], + treatments: Union[list[float], None], +) -> ArmResponse: + """ + Update the arm parameters based on the outcome. + + This is a helper function to allow `auto_fail` job to call + it as well. 
+ """ + update_experiment_metadata(experiment) + + arm = get_arm_from_experiment(experiment, draw.arm_id) + arm.n_outcomes += 1 + + chosen_arm = int( + np.argwhere([a.arm_id == arm.arm_id for a in experiment.arms])[0][0] + ) + + await update_arm_parameters( + arm=arm, + experiment=experiment, + chosen_arm=chosen_arm, + rewards=rewards, + contexts=contexts, + treatments=treatments, + ) + + return ArmResponse.model_validate(arm) + + +def update_experiment_metadata(experiment: ExperimentDB) -> None: + """Update experiment metadata with new trial information""" + experiment.n_trials += 1 + experiment.last_trial_datetime_utc = datetime.now(tz=timezone.utc) + + +def get_arm_from_experiment(experiment: ExperimentDB, arm_id: int) -> ArmDB: + """Get and validate the arm from the experiment""" + arms = [a for a in experiment.arms if a.arm_id == arm_id] + if not arms: + raise HTTPException(status_code=404, detail=f"Arm with id {arm_id} not found") + return arms[0] + + +async def update_arm_parameters( + arm: ArmDB, + experiment: ExperimentDB, + chosen_arm: int, + rewards: list[float], + contexts: Union[list[list[float]], None], + treatments: Union[list[float], None], +) -> None: + """Update the arm parameters based on the reward type and outcome""" + experiment_data = ExperimentSample.model_validate(experiment.to_dict()) + if experiment_data.reward_type == RewardLikelihood.BERNOULLI: + Outcome(rewards[0]) # Check if reward is 0 or 1 + params = update_arm( + experiment=experiment_data, + rewards=rewards, + arm_to_update=chosen_arm, + context=contexts, + treatments=treatments, + ) + + if experiment_data.exp_type == ExperimentsEnum.BAYESAB: + if experiment_data.prior_type == ArmPriors.NORMAL: + mus, covariances = params + for arm in experiment.arms: + if arm.is_treatment_arm: + arm.mu = [mus[0]] + arm.covariance = covariances[0] + else: + arm.mu = [mus[1]] + arm.covariance = covariances[1] + else: + raise HTTPException( + status_code=400, + detail="Prior type not supported for 
Bayesian A/B experiments.", + ) + else: + if experiment_data.prior_type == ArmPriors.BETA: + arm.alpha, arm.beta = params + elif experiment_data.prior_type == ArmPriors.NORMAL: + arm.mu, arm.covariance = params + else: + raise HTTPException( + status_code=400, + detail="Prior type not supported.", + ) + + +async def save_updated_data( + arm: ArmDB, + draw: DrawDB, + reward: float, + observation_type: ObservationType, + asession: AsyncSession, +) -> None: + """Save the updated arm and observation data""" + await asession.commit() + await save_observation_to_db( + draw=draw, reward=reward, observation_type=observation_type, asession=asession + ) diff --git a/backend/app/experiments/models.py b/backend/app/experiments/models.py new file mode 100644 index 0000000..48cb938 --- /dev/null +++ b/backend/app/experiments/models.py @@ -0,0 +1,736 @@ +import uuid +from datetime import datetime, timezone +from typing import Optional, Sequence + +import numpy as np +from sqlalchemy import ( + Boolean, + DateTime, + Enum, + Float, + ForeignKey, + Integer, + String, + delete, + select, +) +from sqlalchemy.dialects.postgresql import ARRAY +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from ..models import Base +from .schemas import ( + AutoFailUnitType, + EventType, + Experiment, + Notifications, + ObservationType, +) + + +# --- Base model for experiments --- +class ExperimentDB(Base): + """ + Base model for experiments. 
+ """ + + __tablename__ = "experiments" + + # IDs + experiment_id: Mapped[int] = mapped_column( + Integer, primary_key=True, nullable=False + ) + user_id: Mapped[int] = mapped_column( + Integer, ForeignKey("users.user_id"), nullable=False + ) + workspace_id: Mapped[int] = mapped_column( + Integer, ForeignKey("workspace.workspace_id"), nullable=False + ) + + # Description + name: Mapped[str] = mapped_column(String(length=150), nullable=False) + description: Mapped[str] = mapped_column(String(length=500), nullable=False) + is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) + + # Assignments config + sticky_assignment: Mapped[bool] = mapped_column( + Boolean, nullable=False, default=False + ) + auto_fail: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + auto_fail_value: Mapped[int] = mapped_column(Integer, nullable=True) + auto_fail_unit: Mapped[AutoFailUnitType] = mapped_column( + Enum(AutoFailUnitType), nullable=True + ) + + # Experiment config + exp_type: Mapped[str] = mapped_column(String(length=50), nullable=False) + prior_type: Mapped[str] = mapped_column(String(length=50), nullable=False) + reward_type: Mapped[str] = mapped_column(String(length=50), nullable=False) + + # State variables + created_datetime_utc: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False + ) + n_trials: Mapped[int] = mapped_column(Integer, nullable=False) + last_trial_datetime_utc: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=True + ) + + # Relationships + arms: Mapped[list["ArmDB"]] = relationship( + "ArmDB", back_populates="experiment", lazy="joined" + ) + draws: Mapped[list["DrawDB"]] = relationship( + "DrawDB", + back_populates="experiment", + primaryjoin="ExperimentDB.experiment_id==DrawDB.experiment_id", + lazy="joined", + ) + clients: Mapped[list["ClientDB"]] = relationship( + "ClientDB", + back_populates="experiment", + lazy="joined", + ) + contexts: 
Mapped[Optional[list["ContextDB"]]] = relationship( + "ContextDB", + back_populates="experiment", + lazy="joined", + primaryjoin="and_(ExperimentDB.experiment_id==ContextDB.experiment_id," + + "ExperimentDB.exp_type=='cmab')", + ) + + def __repr__(self) -> str: + """ + String representation of the model + """ + return f"" + + @property + def has_contexts(self) -> bool: + """Check if this experiment type supports contexts.""" + return self.exp_type == "cmab" + + @property + def context_list(self) -> list["ContextDB"] | list: + """Get contexts, returning empty list if not applicable.""" + return self.contexts if self.has_contexts and self.contexts is not None else [] + + def to_dict(self) -> dict: + """ + Convert the ORM object to a dictionary. + """ + return { + "experiment_id": self.experiment_id, + "user_id": self.user_id, + "workspace_id": self.workspace_id, + "name": self.name, + "description": self.description, + "sticky_assignment": self.sticky_assignment, + "auto_fail": self.auto_fail, + "auto_fail_value": self.auto_fail_value, + "auto_fail_unit": self.auto_fail_unit, + "exp_type": self.exp_type, + "prior_type": self.prior_type, + "reward_type": self.reward_type, + "created_datetime_utc": str(self.created_datetime_utc), + "is_active": self.is_active, + "n_trials": self.n_trials, + "last_trial_datetime_utc": str(self.last_trial_datetime_utc), + "arms": [arm.to_dict() for arm in self.arms], + "draws": [draw.to_dict() for draw in self.draws], + "contexts": ( + [context.to_dict() for context in self.context_list if context] + if len(self.context_list) > 0 + else [] + ), + } + + +# --- Arm model --- +class ArmDB(Base): + """ + Base model for arms. 
+ """ + + __tablename__ = "arms" + + # IDs + arm_id: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False) + workspace_id: Mapped[int] = mapped_column( + Integer, ForeignKey("workspace.workspace_id"), nullable=False + ) + experiment_id: Mapped[int] = mapped_column( + Integer, ForeignKey("experiments.experiment_id"), nullable=False + ) + + # Description + name: Mapped[str] = mapped_column(String(length=150), nullable=False) + description: Mapped[str] = mapped_column(String(length=500), nullable=False) + n_outcomes: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + + # Prior variables + mu_init: Mapped[Optional[float]] = mapped_column(Float, nullable=True) + sigma_init: Mapped[Optional[float]] = mapped_column(Float, nullable=True) + mu: Mapped[Optional[list[float]]] = mapped_column(ARRAY(Float), nullable=True) + covariance: Mapped[Optional[list[float]]] = mapped_column( + ARRAY(Float), nullable=True + ) + is_treatment_arm: Mapped[bool] = mapped_column(Boolean, nullable=True, default=True) + + alpha_init: Mapped[Optional[float]] = mapped_column(Float, nullable=True) + beta_init: Mapped[Optional[float]] = mapped_column(Float, nullable=True) + alpha: Mapped[Optional[float]] = mapped_column(Float, nullable=True) + beta: Mapped[Optional[float]] = mapped_column(Float, nullable=True) + + # Relationships + experiment: Mapped[ExperimentDB] = relationship( + "ExperimentDB", back_populates="arms", lazy="joined" + ) + draws: Mapped[list["DrawDB"]] = relationship( + "DrawDB", + back_populates="arm", + lazy="joined", + ) + + def to_dict(self) -> dict: + """ + Convert the ORM object to a dictionary. 
+ """ + return { + "arm_id": self.arm_id, + "experiment_id": self.experiment_id, + "name": self.name, + "description": self.description, + "alpha": self.alpha, + "beta": self.beta, + "mu": self.mu, + "covariance": self.covariance, + "alpha_init": self.alpha_init, + "beta_init": self.beta_init, + "mu_init": self.mu_init, + "sigma_init": self.sigma_init, + "draws": [draw.to_dict() for draw in self.draws], + "n_outcomes": self.n_outcomes, + } + + +# --- Draw model --- +class DrawDB(Base): + """ + Base model for draws. + """ + + __tablename__ = "draws" + + # IDs + draw_id: Mapped[str] = mapped_column( + String, primary_key=True, default=lambda x: str(uuid.uuid4()) + ) + arm_id: Mapped[int] = mapped_column( + Integer, ForeignKey("arms.arm_id"), nullable=False + ) + experiment_id: Mapped[int] = mapped_column( + Integer, ForeignKey("experiments.experiment_id"), nullable=False + ) + workspace_id: Mapped[int] = mapped_column( + Integer, ForeignKey("workspace.workspace_id"), nullable=False + ) + client_id: Mapped[str] = mapped_column( + String(length=36), ForeignKey("clients.client_id"), nullable=True + ) + + # Logging + draw_datetime_utc: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + nullable=False, + ) + observed_datetime_utc: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=True + ) + observation_type: Mapped[ObservationType] = mapped_column( + Enum(ObservationType), nullable=True + ) + reward: Mapped[float] = mapped_column(Float, nullable=True) + context_val: Mapped[Optional[list[float]]] = mapped_column( + ARRAY(Float), nullable=True + ) + + # Relationships + arm: Mapped[ArmDB] = relationship("ArmDB", back_populates="draws", lazy="joined") + experiment: Mapped[ExperimentDB] = relationship( + "ExperimentDB", back_populates="draws", lazy="joined" + ) + client: Mapped[Optional["ClientDB"]] = relationship( + "ClientDB", + back_populates="draws", + lazy="joined", + primaryjoin="DrawDB.client_id==ClientDB.client_id", # noqa: E501 + ) + + 
def to_dict(self) -> dict: + """ + Convert the ORM object to a dictionary. + """ + return { + "draw_id": self.draw_id, + "arm_id": self.arm_id, + "experiment_id": self.experiment_id, + "client_id": self.client_id, + "draw_datetime_utc": self.draw_datetime_utc, + "observed_datetime_utc": self.observed_datetime_utc, + "observation_type": self.observation_type, + "reward": self.reward, + "context_val": self.context_val, + } + + +# --- Context model --- +class ContextDB(Base): + """ + ORM for managing context for an experiment + """ + + __tablename__ = "context" + + # IDs + context_id: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False) + experiment_id: Mapped[int] = mapped_column( + Integer, ForeignKey("experiments.experiment_id"), nullable=False + ) + workspace_id: Mapped[int] = mapped_column( + Integer, ForeignKey("workspace.workspace_id"), nullable=False + ) + + # Description + name: Mapped[str] = mapped_column(String(length=150), nullable=False) + description: Mapped[str] = mapped_column(String(length=500), nullable=True) + value_type: Mapped[str] = mapped_column(String(length=50), nullable=False) + + # Relationships + experiment: Mapped[ExperimentDB] = relationship( + "ExperimentDB", back_populates="contexts", lazy="joined" + ) + + def to_dict(self) -> dict: + """ + Convert the ORM object to a dictionary. 
+ """ + return { + "context_id": self.context_id, + "name": self.name, + "description": self.description, + "value_type": self.value_type, + } + + +# --- Client model --- +class ClientDB(Base): + """ + ORM for managing clients for an experiment + """ + + __tablename__ = "clients" + + # IDs + client_id: Mapped[str] = mapped_column( + String, primary_key=True, default=lambda x: str(uuid.uuid4()) + ) + experiment_id: Mapped[int] = mapped_column( + Integer, ForeignKey("experiments.experiment_id"), nullable=False + ) + workspace_id: Mapped[int] = mapped_column( + Integer, ForeignKey("workspace.workspace_id"), nullable=False + ) + + # Relationships + draws: Mapped[list[DrawDB]] = relationship( + "DrawDB", + back_populates="client", + lazy="joined", + ) + experiment: Mapped[ExperimentDB] = relationship( + "ExperimentDB", + back_populates="clients", + lazy="joined", + primaryjoin="and_(ClientDB.experiment_id==ExperimentDB.experiment_id," + + "ExperimentDB.sticky_assignment == True)", + ) + + def to_dict(self) -> dict: + """ + Convert the ORM object to a dictionary. + """ + return { + "client_id": self.client_id, + "experiment_id": self.experiment_id, + "workspace_id": self.workspace_id, + "draws": [draw.to_dict() for draw in self.draws], + } + + +# --- Notifications model --- +class NotificationsDB(Base): + """ + Model for notifications. 
+ Note: if you are updating this, you should also update models in + the background celery job + """ + + __tablename__ = "notifications" + + notification_id: Mapped[int] = mapped_column( + Integer, primary_key=True, nullable=False + ) + experiment_id: Mapped[int] = mapped_column( + Integer, ForeignKey("experiments.experiment_id"), nullable=False + ) + user_id: Mapped[int] = mapped_column( + Integer, ForeignKey("users.user_id"), nullable=False + ) + workspace_id: Mapped[int] = mapped_column( + Integer, ForeignKey("workspace.workspace_id"), nullable=False + ) + notification_type: Mapped[EventType] = mapped_column( + Enum(EventType), nullable=False + ) + notification_value: Mapped[int] = mapped_column(Integer, nullable=False) + is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) + + def to_dict(self) -> dict: + """ + Convert the model to a dictionary. + """ + return { + "notification_id": self.notification_id, + "experiment_id": self.experiment_id, + "user_id": self.user_id, + "notification_type": self.notification_type, + "notification_value": self.notification_value, + "is_active": self.is_active, + } + + +# --- ORM functions --- + + +# ---- Notifications functions ---- +async def save_notifications_to_db( + experiment_id: int, + user_id: int, + workspace_id: int, + notifications: Notifications, + asession: AsyncSession, +) -> list[NotificationsDB]: + """ + Save notifications to the database + """ + notification_records = [] + + if notifications.onTrialCompletion: + notification_row = NotificationsDB( + experiment_id=experiment_id, + user_id=user_id, + workspace_id=workspace_id, + notification_type=EventType.TRIALS_COMPLETED, + notification_value=notifications.numberOfTrials, + is_active=True, + ) + notification_records.append(notification_row) + + if notifications.onDaysElapsed: + notification_row = NotificationsDB( + experiment_id=experiment_id, + user_id=user_id, + workspace_id=workspace_id, + notification_type=EventType.DAYS_ELAPSED, + 
async def get_notifications_from_db(
    experiment_id: int, user_id: int, workspace_id: int, asession: AsyncSession
) -> Sequence[NotificationsDB]:
    """
    Fetch every notification row for the given experiment/user/workspace.
    """
    query = select(NotificationsDB).where(
        NotificationsDB.experiment_id == experiment_id,
        NotificationsDB.user_id == user_id,
        NotificationsDB.workspace_id == workspace_id,
    )
    result = await asession.execute(query)
    return result.scalars().all()


# --- Experiment functions ---
async def save_experiment_to_db(
    experiment: Experiment,
    user_id: int,
    workspace_id: int,
    asession: AsyncSession,
) -> ExperimentDB:
    """
    Persist a new experiment, its arms and (for CMABs) its contexts.

    Returns the refreshed ``ExperimentDB`` row with generated IDs populated.
    """
    # Arms without contexts still get a 1-dimensional parameter vector.
    n_contexts = len(experiment.contexts) if experiment.contexts else 1

    def _arm_row(arm: Any) -> ArmDB:
        # Gaussian arms get an isotropic sigma^2 * I covariance sized to the
        # context dimension; arms without sigma_init store a placeholder.
        if arm.sigma_init:
            covariance = (np.identity(n_contexts) * arm.sigma_init**2).tolist()
        else:
            covariance = [[None]]
        return ArmDB(
            workspace_id=workspace_id,
            # description
            name=arm.name,
            description=arm.description,
            n_outcomes=0,
            # prior variables
            mu_init=arm.mu_init,
            sigma_init=arm.sigma_init,
            mu=[arm.mu_init] * n_contexts,
            covariance=covariance,
            alpha_init=arm.alpha_init,
            beta_init=arm.beta_init,
            alpha=arm.alpha_init,
            beta=arm.beta_init,
            is_treatment_arm=arm.is_treatment_arm,
        )

    arms = [_arm_row(arm) for arm in experiment.arms]

    contexts = []
    if experiment.contexts and n_contexts > 0:
        contexts = [
            ContextDB(
                workspace_id=workspace_id,
                name=context.name,
                description=context.description,
                value_type=context.value_type,
            )
            for context in experiment.contexts
        ]

    experiment_db = ExperimentDB(
        user_id=user_id,
        workspace_id=workspace_id,
        # description
        name=experiment.name,
        description=experiment.description,
        is_active=experiment.is_active,
        # assignments config
        sticky_assignment=experiment.sticky_assignment,
        auto_fail=experiment.auto_fail,
        auto_fail_value=experiment.auto_fail_value,
        auto_fail_unit=experiment.auto_fail_unit,
        # experiment config
        exp_type=experiment.exp_type,
        prior_type=experiment.prior_type,
        reward_type=experiment.reward_type,
        # datetime
        created_datetime_utc=datetime.now(timezone.utc),
        n_trials=0,
        # relationships
        arms=arms,
        contexts=contexts,
    )

    asession.add(experiment_db)
    await asession.commit()
    await asession.refresh(experiment_db)

    return experiment_db


async def get_all_experiments_from_db(
    workspace_id: int, asession: AsyncSession
) -> Sequence[ExperimentDB]:
    """
    Fetch all experiments in a workspace, newest first.
    """
    query = (
        select(ExperimentDB)
        .where(ExperimentDB.workspace_id == workspace_id)
        .order_by(ExperimentDB.created_datetime_utc.desc())
    )
    result = await asession.execute(query)
    return result.unique().scalars().all()
+ """ + statement = ( + select(ExperimentDB) + .where(ExperimentDB.workspace_id == workspace_id) + .order_by(ExperimentDB.created_datetime_utc.desc()) + ) + return (await asession.execute(statement)).unique().scalars().all() + + +async def get_all_experiment_types_from_db( + workspace_id: int, experiment_type: str, asession: AsyncSession +) -> Sequence[ExperimentDB]: + """ + Get all experiments for a given workspace. + """ + statement = ( + select(ExperimentDB) + .where(ExperimentDB.workspace_id == workspace_id) + .where(ExperimentDB.exp_type == experiment_type) + .order_by(ExperimentDB.created_datetime_utc.desc()) + ) + return (await asession.execute(statement)).unique().scalars().all() + + +async def get_experiment_by_id_from_db( + workspace_id: int, experiment_id: int, asession: AsyncSession +) -> ExperimentDB | None: + """ + Get all experiments for a given workspace. + """ + statement = ( + select(ExperimentDB) + .where(ExperimentDB.workspace_id == workspace_id) + .where(ExperimentDB.experiment_id == experiment_id) + ) + return (await asession.execute(statement)).unique().scalars().one_or_none() + + +async def delete_experiment_by_id_from_db( + workspace_id: int, experiment_id: int, asession: AsyncSession +) -> None: + """ + Delete an experiment by ID for a given workspace. 
+ """ + await asession.execute( + delete(NotificationsDB) + .where(NotificationsDB.workspace_id == workspace_id) + .where(NotificationsDB.experiment_id == experiment_id) + ) + + await asession.execute( + delete(ContextDB) + .where(ContextDB.workspace_id == workspace_id) + .where(ContextDB.experiment_id == experiment_id) + ) + + await asession.execute( + delete(ClientDB) + .where(ClientDB.workspace_id == workspace_id) + .where(ClientDB.experiment_id == experiment_id) + ) + + await asession.execute( + delete(DrawDB) + .where(DrawDB.workspace_id == workspace_id) + .where(DrawDB.experiment_id == experiment_id) + ) + + await asession.execute( + delete(ArmDB) + .where(ArmDB.workspace_id == workspace_id) + .where(ArmDB.experiment_id == experiment_id) + ) + + await asession.execute( + delete(ExperimentDB) + .where(ExperimentDB.workspace_id == workspace_id) + .where(ExperimentDB.experiment_id == experiment_id) + ) + + await asession.commit() + return None + + +# Draw functions +async def get_draw_by_id(draw_id: str, asession: AsyncSession) -> DrawDB | None: + """ + Get a draw by its ID, which should be unique across the system. + """ + statement = select(DrawDB).where(DrawDB.draw_id == draw_id) + result = await asession.execute(statement) + + return result.unique().scalar_one_or_none() + + +async def save_draw_to_db( + draw_id: str, + arm_id: int, + experiment_id: int, + workspace_id: int, + client_id: str | None, + context: list[float] | None, + asession: AsyncSession, +) -> DrawDB: + """ + Save a draw to the database. 
+ """ + draw = DrawDB( + draw_id=draw_id, + arm_id=arm_id, + experiment_id=experiment_id, + workspace_id=workspace_id, + client_id=client_id, + draw_datetime_utc=datetime.now(timezone.utc), + context_val=context, + ) + asession.add(draw) + await asession.commit() + await asession.refresh(draw) + + return draw + + +async def save_observation_to_db( + draw: DrawDB, + reward: float, + observation_type: ObservationType, + asession: AsyncSession, +) -> DrawDB: + """ + Save an observation to the database. + """ + draw.observed_datetime_utc = datetime.now(timezone.utc) + draw.observation_type = observation_type + draw.reward = reward + + await asession.commit() + await asession.refresh(draw) + + return draw + + +async def get_draws_by_experiment_id( + experiment_id: int, asession: AsyncSession +) -> Sequence[DrawDB]: + """ + Get all draws for a given experiment ID. + """ + statement = ( + select(DrawDB) + .where(DrawDB.experiment_id == experiment_id) + .order_by(DrawDB.draw_datetime_utc.desc()) + ) + return (await asession.execute(statement)).unique().scalars().all() + + +async def get_draws_with_rewards_by_experiment_id( + experiment_id: int, asession: AsyncSession +) -> Sequence[DrawDB]: + """ + Get all draws with rewards for a given experiment ID. 
+ """ + statement = ( + select(DrawDB) + .where(DrawDB.experiment_id == experiment_id) + .where(DrawDB.reward.is_not(None)) + .order_by(DrawDB.draw_datetime_utc.desc()) + ) + return (await asession.execute(statement)).unique().scalars().all() diff --git a/backend/app/experiments/routers.py b/backend/app/experiments/routers.py new file mode 100644 index 0000000..58431ad --- /dev/null +++ b/backend/app/experiments/routers.py @@ -0,0 +1,489 @@ +from typing import Annotated, Optional +from uuid import uuid4 + +import numpy as np +from fastapi import APIRouter, Depends +from fastapi.exceptions import HTTPException +from sqlalchemy.ext.asyncio import AsyncSession + +from ..auth.dependencies import ( + authenticate_workspace_key, + get_verified_user, + require_admin_role, +) +from ..database import get_async_session +from ..users.models import UserDB +from ..utils import setup_logger +from ..workspaces.models import ( + WorkspaceDB, + get_user_default_workspace, +) +from .dependencies import ( + experiments_db_to_schema, + format_rewards_for_arm_update, + save_updated_data, + update_arm_based_on_outcome, + validate_experiment_and_draw, +) +from .models import ( + delete_experiment_by_id_from_db, + get_all_experiment_types_from_db, + get_all_experiments_from_db, + get_draw_by_id, + get_draws_by_experiment_id, + get_experiment_by_id_from_db, + save_draw_to_db, + save_experiment_to_db, + save_notifications_to_db, +) +from .sampling_utils import choose_arm +from .schemas import ( + ArmResponse, + ContextInput, + ContextType, + DrawResponse, + Experiment, + ExperimentSample, + ExperimentsEnum, + Outcome, +) + +router = APIRouter(prefix="/experiment", tags=["Experiments"]) + +logger = setup_logger(__name__) + + +# --- POST experiments routers --- +@router.post("/", response_model=ExperimentSample) +async def create_experiment( + experiment: Experiment, + user_db: Annotated[UserDB, Depends(require_admin_role)], + asession: AsyncSession = Depends(get_async_session), +) -> 
async def _default_workspace_or_404(
    user_db: UserDB, asession: AsyncSession
) -> WorkspaceDB:
    """Return the user's default workspace, raising a 404 if none exists."""
    workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db)
    if workspace_db is None:
        raise HTTPException(
            status_code=404,
            detail="Workspace not found. Please create a workspace first.",
        )
    return workspace_db


@router.post("/", response_model=ExperimentSample)
async def create_experiment(
    experiment: Experiment,
    user_db: Annotated[UserDB, Depends(require_admin_role)],
    asession: AsyncSession = Depends(get_async_session),
) -> ExperimentSample:
    """
    Create a new experiment in the current user's workspace.

    Persists the experiment plus its notification settings and returns the
    stored record. Requires the admin role.
    """
    workspace_db = await _default_workspace_or_404(user_db, asession)

    experiment_db = await save_experiment_to_db(
        experiment=experiment,
        workspace_id=workspace_db.workspace_id,
        user_id=user_db.user_id,
        asession=asession,
    )
    notifications = await save_notifications_to_db(
        experiment_id=experiment_db.experiment_id,
        user_id=user_db.user_id,
        workspace_id=workspace_db.workspace_id,
        notifications=experiment.notifications,
        asession=asession,
    )

    payload = experiment_db.to_dict()
    payload["notifications"] = [n.to_dict() for n in notifications]
    return ExperimentSample.model_validate(payload)


# -- GET experiment routers ---
@router.get("/", response_model=list[ExperimentSample])
async def get_all_experiments(
    user_db: Annotated[UserDB, Depends(get_verified_user)],
    asession: AsyncSession = Depends(get_async_session),
) -> list[ExperimentSample]:
    """
    Retrieve all experiments for the current user's workspace.
    """
    workspace_db = await _default_workspace_or_404(user_db, asession)

    experiments = await get_all_experiments_from_db(
        workspace_id=workspace_db.workspace_id,
        asession=asession,
    )

    return await experiments_db_to_schema(
        experiments_db=list(experiments),
        asession=asession,
    )


@router.get("/type/{experiment_type}", response_model=list[ExperimentSample])
async def get_all_experiments_by_type(
    experiment_type: ExperimentsEnum,
    user_db: Annotated[UserDB, Depends(get_verified_user)],
    asession: AsyncSession = Depends(get_async_session),
) -> list[ExperimentSample]:
    """
    Retrieve all experiments of one type for the current user's workspace.
    """
    workspace_db = await _default_workspace_or_404(user_db, asession)

    experiments = await get_all_experiment_types_from_db(
        workspace_id=workspace_db.workspace_id,
        experiment_type=experiment_type.value,
        asession=asession,
    )

    return await experiments_db_to_schema(
        experiments_db=list(experiments),
        asession=asession,
    )


@router.get("/id/{experiment_id}", response_model=ExperimentSample)
async def get_experiment_by_id(
    experiment_id: int,
    user_db: Annotated[UserDB, Depends(get_verified_user)],
    asession: AsyncSession = Depends(get_async_session),
) -> ExperimentSample:
    """
    Retrieve a specific experiment by ID for the current user's workspace.
    """
    workspace_db = await _default_workspace_or_404(user_db, asession)

    experiment = await get_experiment_by_id_from_db(
        workspace_id=workspace_db.workspace_id,
        experiment_id=experiment_id,
        asession=asession,
    )
    if not experiment:
        raise HTTPException(
            status_code=404,
            detail="Experiment not found.",
        )

    schemas = await experiments_db_to_schema(
        experiments_db=[experiment],
        asession=asession,
    )
    return schemas[0]
# -- DELETE experiment routers ---
@router.delete("/type/{experiment_type}", response_model=dict[str, str])
async def delete_experiment_by_type(
    experiment_type: ExperimentsEnum,
    user_db: Annotated[UserDB, Depends(get_verified_user)],
    asession: AsyncSession = Depends(get_async_session),
) -> dict[str, str]:
    """
    Delete all experiments of the given type in the current user's workspace.

    Raises 404 if the workspace or any matching experiments are missing, and
    500 for unexpected failures during deletion.
    """
    try:
        workspace_db = await get_user_default_workspace(
            asession=asession, user_db=user_db
        )

        if workspace_db is None:
            raise HTTPException(
                status_code=404,
                detail="Workspace not found. Please create a workspace first.",
            )

        experiments = await get_all_experiment_types_from_db(
            workspace_id=workspace_db.workspace_id,
            experiment_type=experiment_type.value,
            asession=asession,
        )

        if len(experiments) == 0:
            raise HTTPException(
                status_code=404,
                detail="No experiments found.",
            )

        for exp in experiments:
            await delete_experiment_by_id_from_db(
                workspace_id=workspace_db.workspace_id,
                experiment_id=exp.experiment_id,
                asession=asession,
            )

        return {
            "message": f"Experiments of type {experiment_type} deleted successfully."
        }
    except HTTPException:
        # Bug fix: without this clause the broad handler below rewrapped the
        # deliberate 404s above into opaque 500 responses.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error: {str(e)}",
        ) from e


@router.delete("/id/{experiment_id}", response_model=dict[str, str])
async def delete_experiment_by_id(
    experiment_id: int,
    user_db: Annotated[UserDB, Depends(get_verified_user)],
    asession: AsyncSession = Depends(get_async_session),
) -> dict[str, str]:
    """
    Delete a specific experiment by ID in the current user's workspace.

    Raises 404 if the workspace or experiment is missing, and 500 for
    unexpected failures during deletion.
    """
    try:
        workspace_db = await get_user_default_workspace(
            asession=asession, user_db=user_db
        )

        if workspace_db is None:
            raise HTTPException(
                status_code=404,
                detail="Workspace not found. Please create a workspace first.",
            )

        experiment = await get_experiment_by_id_from_db(
            workspace_id=workspace_db.workspace_id,
            experiment_id=experiment_id,
            asession=asession,
        )

        if not experiment:
            raise HTTPException(
                status_code=404,
                detail="Experiment not found.",
            )

        await delete_experiment_by_id_from_db(
            workspace_id=workspace_db.workspace_id,
            experiment_id=experiment_id,
            asession=asession,
        )

        return {"message": f"Experiment with id {experiment_id} deleted successfully."}
    except HTTPException:
        # Bug fix: preserve intentional HTTP errors (404) instead of
        # converting them into 500s below.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error: {str(e)}",
        ) from e
+ """ + workspace_id = workspace_db.workspace_id + + experiment = await get_experiment_by_id_from_db( + workspace_id=workspace_id, experiment_id=experiment_id, asession=asession + ) + if experiment is None: + raise HTTPException( + status_code=404, detail=f"Experiment with id {experiment_id} not found" + ) + + # Check contexts + if (experiment.exp_type == ExperimentsEnum.CMAB.value) and (not contexts): + raise HTTPException( + status_code=400, detail="Context is required for CMAB experiments." + ) + elif (experiment.exp_type == ExperimentsEnum.CMAB.value) and contexts: + context_length = 0 if not experiment.contexts else len(experiment.contexts) + if len(contexts) != context_length: + raise HTTPException( + status_code=400, + detail=( + f"Expected {context_length} contexts" f" but got {len(contexts)}." + ), + ) + + # Check for existing draws + if draw_id is None: + draw_id = str(uuid4()) + + existing_draw = await get_draw_by_id(draw_id=draw_id, asession=asession) + if existing_draw: + raise HTTPException( + status_code=400, detail=f"Draw with id {draw_id} already exists." + ) + + # -- Perform the draw --- + experiment_data = ExperimentSample.model_validate(experiment.to_dict()) + + # Validate contexts input + if contexts: + sorted_contexts = list(sorted(contexts, key=lambda x: x.context_id)) + try: + exp_contexts = experiment_data.contexts or [] + sorted_exp_contexts = ( + sorted(exp_contexts, key=lambda x: x.context_id) if exp_contexts else [] + ) + if [c1.context_id for c1 in sorted_contexts] != [ + c2.context_id for c2 in sorted_exp_contexts + ]: + raise ValueError( + "Provided contexts do not match the experiment's expected contexts." 
+ ) + for c_input, c_exp in zip( + sorted_contexts, + sorted_exp_contexts, + ): + if c_exp.value_type == ContextType.BINARY.value: + Outcome(c_input.context_value) + except ValueError as e: + raise HTTPException( + status_code=400, + detail=f"Invalid context value: {e}", + ) from e + + # Choose arm + chosen_arm = choose_arm( + experiment=experiment_data, + context=[c.context_value for c in sorted_contexts] if contexts else None, + ) + chosen_arm_id = experiment.arms[chosen_arm].arm_id + + try: + draw = await save_draw_to_db( + draw_id=draw_id, + arm_id=chosen_arm_id, + experiment_id=experiment_id, + workspace_id=workspace_id, + client_id=None, # TODO: Update for sticky assignment + context=[c.context_value for c in sorted_contexts] if contexts else None, + asession=asession, + ) + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error saving draw: {str(e)}", + ) from e + + draw_response_data = { + "draw_id": draw_id, + "draw_datetime_utc": str(draw.draw_datetime_utc), + "arm": experiment_data.arms[chosen_arm], + "context_val": draw.context_val, + } + return DrawResponse.model_validate(draw_response_data) + + +@router.put("/{experiment_id}/{draw_id}/{reward}", response_model=ArmResponse) +async def update_experiment_arm( + experiment_id: int, + draw_id: str, + reward: float, + workspace_db: WorkspaceDB = Depends(authenticate_workspace_key), + asession: AsyncSession = Depends(get_async_session), +) -> ArmResponse: + """ + Update the arm with the given reward. 
+ """ + + experiment, draw = await validate_experiment_and_draw( + experiment_id=experiment_id, + draw_id=draw_id, + workspace_id=workspace_db.workspace_id, + asession=asession, + ) + + # Get rewards + chosen_arm_index = int( + np.argwhere(np.array([arm.arm_id for arm in experiment.arms]) == draw.arm_id)[ + 0 + ][0], + ) + rewards_list, context_list, treatments_list = await format_rewards_for_arm_update( + experiment=experiment, + chosen_arm_id=draw.arm_id, + reward=reward, + context_val=draw.context_val, + asession=asession, + ) + + # Update the arm with the given reward + try: + await update_arm_based_on_outcome( + experiment=experiment, + draw=draw, + rewards=rewards_list, + contexts=context_list, + treatments=treatments_list, + ) + + observation_type = draw.observation_type + + await save_updated_data( + arm=experiment.arms[chosen_arm_index], + draw=draw, + reward=reward, + observation_type=observation_type, + asession=asession, + ) + return ArmResponse.model_validate(experiment.arms[chosen_arm_index]) + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error updating arm: {str(e)}", + ) from e + + +@router.get("/{experiment_id}/rewards", response_model=list[DrawResponse]) +async def get_rewards( + experiment_id: int, + workspace_db: WorkspaceDB = Depends(authenticate_workspace_key), + asession: AsyncSession = Depends(get_async_session), +) -> list[DrawResponse]: + """ + Retrieve all rewards for the specified experiment. 
+ """ + experiment = await get_experiment_by_id_from_db( + workspace_id=workspace_db.workspace_id, + experiment_id=experiment_id, + asession=asession, + ) + + if not experiment: + raise HTTPException( + status_code=404, detail=f"Experiment with id {experiment_id} not found" + ) + + draws = await get_draws_by_experiment_id( + experiment_id=experiment_id, asession=asession + ) + + return [ + DrawResponse.model_validate( + { + "draw_id": draw.draw_id, + "draw_datetime_utc": str(draw.draw_datetime_utc), + "observed_datetime_utc": str(draw.observed_datetime_utc), + "arm": [arm for arm in experiment.arms if arm.arm_id == draw.arm_id][0], + "reward": draw.reward, + "context_val": draw.context_val, + } + ) + for draw in draws + ] diff --git a/backend/app/experiments/sampling_utils.py b/backend/app/experiments/sampling_utils.py new file mode 100644 index 0000000..ac9db72 --- /dev/null +++ b/backend/app/experiments/sampling_utils.py @@ -0,0 +1,326 @@ +from typing import Any, Optional, Union + +import numpy as np +from numpy.random import beta +from scipy.optimize import minimize + +from .schemas import ( + ArmPriors, + ContextLinkFunctions, + ExperimentSample, + ExperimentsEnum, + Outcome, + RewardLikelihood, +) + + +# ------------- Utilities for sampling and updating arms ---------------- +# --- Sampling functions for Thompson Sampling --- +def _sample_beta_binomial(alphas: np.ndarray, betas: np.ndarray) -> int: + """ + Thompson Sampling with Beta-Binomial distribution. + + Parameters + ---------- + alphas : alpha parameter of Beta distribution for each arm + betas : beta parameter of Beta distribution for each arm + """ + samples = beta(alphas, betas) + return int(samples.argmax()) + + +def _sample_normal( + mus: list[np.ndarray], + covariances: list[np.ndarray], + context: np.ndarray, + link_function: ContextLinkFunctions, +) -> int: + """ + Thompson Sampling with normal prior. 
+ + Parameters + ---------- + mus: mean of Normal distribution for each arm + covariances: covariance matrix of Normal distribution for each arm + context: context vector + link_function: link function for the context + """ + samples = np.array( + [ + np.random.multivariate_normal(mean=mu, cov=cov) + for mu, cov in zip(mus, covariances) + ] + ).reshape(-1, len(context)) + probs = link_function(samples @ context) + return int(probs.argmax()) + + +# --- Arm update functions --- +def _update_arm_beta_binomial( + alpha: float, beta: float, reward: Outcome +) -> tuple[float, float]: + """ + Update the alpha and beta parameters of the Beta distribution. + + Parameters + ---------- + alpha : int + The alpha parameter of the Beta distribution. + beta : int + The beta parameter of the Beta distribution. + reward : Outcome + The reward of the arm. + """ + if reward == Outcome.SUCCESS: + + return alpha + 1, beta + else: + return alpha, beta + 1 + + +def _update_arm_normal( + current_mu: np.ndarray, + current_covariance: np.ndarray, + reward: float, + llhood_sigma: float, + context: np.ndarray, +) -> tuple[float, np.ndarray]: + """ + Update the mean and standard deviation of the Normal distribution. + + Parameters + ---------- + current_mu : The mean of the Normal distribution. + current_covariance : The covariance of the Normal distribution. + reward : The reward of the arm. + llhood_sigma : The standard deviation of the likelihood. + context : The context vector. 
+ """ + # Likelihood covariance matrix inverse + llhood_covariance_inv = np.eye(len(current_mu)) / llhood_sigma**2 + llhood_covariance_inv *= context.T @ context + + # Prior covariance matrix inverse + prior_covariance_inv = np.linalg.inv(current_covariance) + + # New covariance + new_covariance = np.linalg.inv(prior_covariance_inv + llhood_covariance_inv) + + # New mean + llhood_term: Union[np.ndarray, float] = reward / llhood_sigma**2 + if context is not None: + llhood_term = (context * llhood_term).squeeze() + + new_mu = new_covariance @ ((prior_covariance_inv @ current_mu) + llhood_term) + return new_mu.tolist(), new_covariance.tolist() + + +def _update_arm_laplace( + current_mu: np.ndarray, + current_covariance: np.ndarray, + reward: np.ndarray, + context: np.ndarray, + link_function: ContextLinkFunctions, + reward_likelihood: RewardLikelihood, + prior_type: ArmPriors, +) -> tuple[np.ndarray, np.ndarray]: + """ + Update the mean and covariance using the Laplace approximation. + + Parameters + ---------- + current_mu : The mean of the normal distribution. + current_covariance : The covariance matrix of the normal distribution. + reward : The list of rewards for the arm. + context : The list of contexts for the arm. + link_function : The link function for parameters to rewards. + reward_likelihood : The likelihood function of the reward. + prior_type : The prior type of the arm. + """ + print(current_mu.shape, current_covariance.shape, reward.shape, context.shape) + + def objective(theta: np.ndarray) -> float: + """ + Objective function for the Laplace approximation. + + Parameters + ---------- + theta : The parameters of the arm. 
+        """
+        # Log prior
+        log_prior = prior_type(theta, mu=current_mu, covariance=current_covariance)
+
+        # Log likelihood
+        log_likelihood = reward_likelihood(reward, link_function(context @ theta))
+
+        return -log_prior - log_likelihood
+
+    result = minimize(
+        objective, x0=np.zeros_like(current_mu), method="L-BFGS-B"  # hess unused by L-BFGS-B
+    )
+    new_mu = result.x
+    covariance = result.hess_inv.todense()  # type: ignore
+
+    new_covariance = 0.5 * (covariance + covariance.T)
+    return new_mu.tolist(), new_covariance.tolist()
+
+
+# ------------- Exported functions ----------------
+# --- Choose arm function ---
+def choose_arm(
+    experiment: ExperimentSample, context: Optional[Union[list, np.ndarray, None]]
+) -> int:
+    """
+    Choose arm based on posterior using Thompson Sampling.
+
+    Parameters
+    ----------
+    experiment: The experiment data containing priors and rewards for each arm.
+    context: Optional context vector for the experiment.
+    """
+    # Choose arms with equal probability for Bayesian A/B tests
+    if experiment.exp_type == ExperimentsEnum.BAYESAB:
+        index = np.random.choice(len(experiment.arms), size=1)
+        return int(index[0])
+    else:
+        if experiment.prior_type == ArmPriors.BETA:
+            if experiment.reward_type != RewardLikelihood.BERNOULLI:
+                raise ValueError("Beta prior is only supported for Bernoulli rewards.")
+            alphas = np.array([arm.alpha for arm in experiment.arms])
+            betas = np.array([arm.beta for arm in experiment.arms])
+
+            return _sample_beta_binomial(alphas=alphas, betas=betas)
+
+        elif experiment.prior_type == ArmPriors.NORMAL:
+            mus = [np.array(arm.mu) for arm in experiment.arms]
+            covariances = [np.array(arm.covariance) for arm in experiment.arms]
+
+            context_array = (
+                np.ones_like(mus[0]) if context is None else np.array(context)
+            )
+
+            return _sample_normal(
+                mus=mus,
+                covariances=covariances,
+                context=context_array,
+                link_function=(
+                    ContextLinkFunctions.NONE
+                    if experiment.reward_type == RewardLikelihood.NORMAL
+                    else ContextLinkFunctions.LOGISTIC
+                ),
+            )
+
+
+# --- Update arm parameters ---
+def update_arm(
+    experiment: ExperimentSample,
+    rewards: list[float],
+    arm_to_update: Optional[int] = None,
+    context: Optional[Union[list, np.ndarray, None]] = None,
+    treatments: Optional[list[float]] = None,
+) -> Any:
+    """
+    Update the arm parameters based on the experiment type and reward.
+
+    Parameters
+    ----------
+    experiment: The experiment data containing arms, prior type and reward
+        type information.
+    rewards: The rewards received from the arm.
+    context: The context vector for the arm.
+    treatments: The treatments applied to the arm, for a Bayesian A/B test.
+    """
+
+    # NB: For Bayesian AB tests, we assume that the update runs
+    # AFTER all rewards have been observed.
+    # We hijack the Laplace approximation function to update the
+    # model parameters as follows:
+    # 1. current_mu -> [treatment_mu, control_mu, bias_mu = 0]
+    # 2. current_covariance -> [treatment_sigma, control_sigma, bias_sigma = 1]
+    # 3. context -> [is_treatment_arm, is_control_arm, 1]
+    if experiment.exp_type == ExperimentsEnum.BAYESAB:
+
+        assert treatments, "Treatments must be provided for Bayesian A/B tests."
+        assert all(
+            arm.mu is not None for arm in experiment.arms
+        ), "Arms must have mu parameters for Bayesian A/B tests."
+        assert all(
+            arm.covariance is not None for arm in experiment.arms
+        ), "Arms must have covariance parameters for Bayesian A/B tests."
+
+        mus = np.array([arm.mu[0] for arm in experiment.arms if arm.mu] + [0.0])
+        covariances = np.diag(
+            [
+                np.array(arm.covariance).ravel()[0]
+                for arm in experiment.arms
+                if arm.covariance
+            ]
+            + [1.0]
+        )
+        context = np.zeros((len(rewards), 3)) if context is None else np.array(context)
+        # Columns encode arm membership: 0 = treatment flag, 1 = control flag, 2 = bias.
+        context[:, 0] = np.array(treatments)
+        context[:, 1] = 1.0 - np.array(treatments)
+        context[:, 2] = 1.0
+
+        new_mus, new_covariances = _update_arm_laplace(
+            current_mu=mus,
+            current_covariance=covariances,
+            reward=np.array(rewards),
+            context=context,
+            link_function=(
+                ContextLinkFunctions.NONE
+                if experiment.reward_type == RewardLikelihood.NORMAL
+                else ContextLinkFunctions.LOGISTIC
+            ),
+            reward_likelihood=experiment.reward_type,
+            prior_type=experiment.prior_type,
+        )
+
+        treatment_mu, control_mu, _ = new_mus
+        treatment_sigma, control_sigma, _ = np.diag(new_covariances)
+        return [treatment_mu, control_mu], [
+            [[float(treatment_sigma)]],
+            [[float(control_sigma)]],
+        ]
+    else:
+        # Update for MABs and CMABs
+        assert arm_to_update is not None, "Arm to update must be provided."
+        arm = experiment.arms[arm_to_update]
+
+        # Beta-binomial priors
+        if experiment.prior_type == ArmPriors.BETA:
+            assert arm.alpha and arm.beta, "Arm must have alpha and beta parameters."
+            return _update_arm_beta_binomial(
+                alpha=arm.alpha, beta=arm.beta, reward=Outcome(rewards[0])
+            )
+
+        # Normal priors
+        elif experiment.prior_type == ArmPriors.NORMAL:
+            assert (
+                arm.mu and arm.covariance
+            ), "Arm must have mu and covariance parameters."
+ if context is None: + context = np.ones((1, len(arm.mu))) + # Normal likelihood + if experiment.reward_type == RewardLikelihood.NORMAL: + return _update_arm_normal( + current_mu=np.array(arm.mu), + current_covariance=np.array(arm.covariance), + reward=rewards[0], + llhood_sigma=1.0, # TODO: Assuming a fixed likelihood sigma + context=np.array(context[0]), + ) + # TODO: only supports Bernoulli likelihood + else: + return _update_arm_laplace( + current_mu=np.array(arm.mu), + current_covariance=np.array(arm.covariance), + reward=np.array(rewards), + context=np.array(context), + link_function=ContextLinkFunctions.LOGISTIC, + reward_likelihood=experiment.reward_type, + prior_type=experiment.prior_type, + ) + else: + raise ValueError("Unsupported prior type for arm update.") diff --git a/backend/app/experiments/schemas.py b/backend/app/experiments/schemas.py new file mode 100644 index 0000000..ffd7e6e --- /dev/null +++ b/backend/app/experiments/schemas.py @@ -0,0 +1,548 @@ +from enum import Enum, StrEnum +from typing import Any, List, Optional, Self, Union + +import numpy as np +from pydantic import BaseModel, ConfigDict, Field, model_validator +from pydantic.types import NonNegativeInt + + +# --- Enums --- +class ExperimentsEnum(StrEnum): + """ + Enum for the experiment types. + """ + + MAB = "mab" + CMAB = "cmab" + BAYESAB = "bayes_ab" + + +class EventType(StrEnum): + """Types of events that can trigger a notification""" + + DAYS_ELAPSED = "days_elapsed" + TRIALS_COMPLETED = "trials_completed" + PERCENTAGE_BETTER = "percentage_better" + + +class ObservationType(StrEnum): + """Types of observations that can be made""" + + USER = "user" # Generated by the user + AUTO = "auto" # Generated by the system + + +class AutoFailUnitType(StrEnum): + """Types of units for auto fail""" + + DAYS = "days" + HOURS = "hours" + + +class Outcome(float, Enum): + """ + Enum for the outcome of a trial. 
+ """ + + SUCCESS = 1 + FAILURE = 0 + + +class ArmPriors(StrEnum): + """ + Enum for the prior distribution of the arm. + """ + + BETA = "beta" + NORMAL = "normal" + + def __call__(self, theta: np.ndarray, **kwargs: Any) -> np.ndarray: + """ + Return the log pdf of the input param. + """ + if self == ArmPriors.BETA: + alpha = kwargs.get("alpha", np.ones_like(theta)) + beta = kwargs.get("beta", np.ones_like(theta)) + return (alpha - 1) * np.log(theta) + (beta - 1) * np.log(1 - theta) + + elif self == ArmPriors.NORMAL: + mu = kwargs.get("mu", np.zeros_like(theta)) + covariance = kwargs.get("covariance", np.diag(np.ones_like(theta))) + inv_cov = np.linalg.inv(covariance) + x = theta - mu + return -0.5 * x @ inv_cov @ x + + +class RewardLikelihood(StrEnum): + """ + Enum for the likelihood distribution of the reward. + """ + + BERNOULLI = "binary" + NORMAL = "real-valued" + + def __call__(self, reward: np.ndarray, probs: np.ndarray) -> np.ndarray: + """ + Calculate the log likelihood of the reward. + + Parameters + ---------- + reward : The reward. + probs : The probability of the reward. + """ + if self == RewardLikelihood.NORMAL: + return -0.5 * np.sum((reward - probs) ** 2) + elif self == RewardLikelihood.BERNOULLI: + return np.sum(reward * np.log(probs) + (1 - reward) * np.log(1 - probs)) + + +class ContextType(StrEnum): + """ + Enum for the type of context. + """ + + BINARY = "binary" + REAL_VALUED = "real-valued" + + +class ContextLinkFunctions(StrEnum): + """ + Enum for the link function of the arm params and context. + """ + + NONE = "none" + LOGISTIC = "logistic" + + def __call__(self, x: np.ndarray) -> np.ndarray: + """ + Apply the link function to the input param. + + Parameters + ---------- + x : The input param. 
+        """
+        if self == ContextLinkFunctions.NONE:
+            return x
+        elif self == ContextLinkFunctions.LOGISTIC:
+            return 1.0 / (1.0 + np.exp(-x))
+
+
+# --- Schemas ---
+# Notifications schema
+class Notifications(BaseModel):
+    """
+    Pydantic model for notifications.
+    """
+
+    onTrialCompletion: bool = False
+    numberOfTrials: NonNegativeInt | None
+    onDaysElapsed: bool = False
+    daysElapsed: NonNegativeInt | None
+    onPercentBetter: bool = False
+    percentBetterThreshold: NonNegativeInt | None
+
+    @model_validator(mode="after")
+    def validate_has_associated_value(self) -> Self:
+        """
+        Validate that the required corresponding fields have been set.
+        """
+        if self.onTrialCompletion and (
+            not self.numberOfTrials or self.numberOfTrials == 0
+        ):
+            raise ValueError(
+                "numberOfTrials is required when onTrialCompletion is True"
+            )
+        if self.onDaysElapsed and (not self.daysElapsed or self.daysElapsed == 0):
+            raise ValueError("daysElapsed is required when onDaysElapsed is True")
+        if self.onPercentBetter and (
+            not self.percentBetterThreshold or self.percentBetterThreshold == 0
+        ):
+            raise ValueError(
+                "percentBetterThreshold is required when onPercentBetter is True"
+            )
+
+        return self
+
+
+class NotificationsResponse(BaseModel):
+    """
+    Pydantic model for a response for notifications
+    """
+
+    model_config = ConfigDict(from_attributes=True)
+
+    notification_id: int
+    notification_type: EventType
+    notification_value: NonNegativeInt
+    is_active: bool
+
+
+# Arms
+class Arm(BaseModel):
+    """
+    Pydantic model for an arm.
+    """
+
+    model_config = ConfigDict(from_attributes=True)
+
+    # Description
+    name: str = Field(
+        max_length=150,
+        examples=["Arm 1"],
+    )
+    description: str = Field(
+        max_length=500,
+        examples=["This is a description of the arm."],
+    )
+
+    # Prior variables
+    alpha_init: Optional[float] = Field(
+        default=None, examples=[None, 1.0], description="Alpha parameter for Beta prior"
+    )
+    beta_init: Optional[float] = Field(
+        default=None, examples=[None, 1.0], description="Beta parameter for Beta prior"
+    )
+    mu_init: Optional[float] = Field(
+        default=None,
+        examples=[None, 0.0],
+        description="Mean parameter for Normal prior",
+    )
+    sigma_init: Optional[float] = Field(
+        default=None,
+        examples=[None, 1.0],
+        description="Standard deviation parameter for Normal prior",
+    )
+    is_treatment_arm: Optional[bool] = Field(
+        default=True,
+        description="Whether the arm is a treatment arm or not",
+    )
+
+    @model_validator(mode="after")
+    def check_values(self) -> Self:
+        """
+        Validate that alpha, beta and sigma, when provided, are strictly positive.
+        """
+        alpha = self.alpha_init
+        beta = self.beta_init
+        sigma = self.sigma_init
+        if alpha is not None and alpha <= 0:
+            raise ValueError("Alpha must be greater than 0.")
+        if beta is not None and beta <= 0:
+            raise ValueError("Beta must be greater than 0.")
+        if sigma is not None and sigma <= 0:
+            raise ValueError("Sigma must be greater than 0.")
+        return self
+
+
+class ArmResponse(Arm):
+    """
+    Pydantic model for a response for arm creation
+    """
+
+    arm_id: int
+    experiment_id: int
+    n_outcomes: int
+    alpha: Optional[Union[float, None]]
+    beta: Optional[Union[float, None]]
+    mu: Optional[List[Union[float, None]]]
+    covariance: Optional[List[List[Union[float, None]]]]
+    model_config = ConfigDict(
+        from_attributes=True,
+    )
+
+
+# Contexts
+class Context(BaseModel):
+    """
+    Pydantic model for a context of the experiment (binary or real-valued).
+    """
+
+    name: str = Field(
+        description="Name of the context",
+        examples=["Context 1"],
+    )
+    description: str = Field(
+        description="Description of the context",
+        examples=["This is a description of the context."],
+    )
+    value_type: ContextType = Field(
+        description="Type of value the context can take", default=ContextType.BINARY
+    )
+    model_config = ConfigDict(from_attributes=True)
+
+
+class ContextResponse(Context):
+    """
+    Pydantic model for a response for context creation
+    """
+
+    context_id: int
+    model_config = ConfigDict(from_attributes=True)
+
+
+class ContextInput(BaseModel):
+    """
+    Pydantic model for a context input
+    """
+
+    context_id: int
+    context_value: float
+    model_config = ConfigDict(from_attributes=True)
+
+
+# Client
+class Client(BaseModel):
+    """
+    Pydantic model for a client.
+    """
+
+    model_config = ConfigDict(from_attributes=True)
+
+    client_id: str = Field(
+        description="Unique identifier for the client",
+        examples=["client_123"],
+    )
+
+
+class DrawResponse(BaseModel):
+    """
+    Pydantic model for a response for draw creation
+    """
+
+    model_config = ConfigDict(from_attributes=True)
+
+    draw_id: str = Field(
+        description="Unique identifier for the draw",
+        examples=["draw_123"],
+    )
+    draw_datetime_utc: str = Field(
+        description="Timestamp of when the draw was made",
+        examples=["2023-10-01T12:00:00Z"],
+    )
+    observed_datetime_utc: Optional[str] = Field(
+        description="Timestamp of when the reward was observed",
+        default=None,
+    )
+
+    # Draw info
+    reward: Optional[float] = Field(
+        description="Reward observed from the draw",
+        default=None,
+    )
+    context_val: Optional[list[float]] = Field(
+        description="Context values associated with the draw",
+        default=None,
+    )
+    arm: ArmResponse
+    client: Optional[Client] = None
+
+
+# Experiments
+class ExperimentBase(BaseModel):
+    """
+    Pydantic base model for an experiment.
+
+    Note: This is a base model and should not be used directly.
+    Use the `Experiment` model instead.
+ """ + + model_config = ConfigDict(from_attributes=True) + + # Description + name: str = Field( + max_length=150, + examples=["Experiment 1"], + ) + description: str = Field( + max_length=500, + examples=["This is a description of the experiment."], + ) + + is_active: bool = True + + # Assignments config + sticky_assignment: bool = Field( + description="Whether the arm assignment is sticky or not.", + default=False, + ) + + auto_fail: bool = Field( + description=( + "Whether the experiment should fail automatically after " + "a certain period if no outcome is registered." + ), + default=False, + ) + + auto_fail_value: Optional[int] = Field( + description="The time period after which the experiment should fail.", + default=None, + ) + + auto_fail_unit: Optional[AutoFailUnitType] = Field( + description="The time unit for the auto fail period.", + default=None, + ) + + # Experiment config + exp_type: ExperimentsEnum = Field( + description="The type of experiment.", + default=ExperimentsEnum.MAB, + ) + prior_type: ArmPriors = Field( + description="The type of prior distribution for the arms.", + default=ArmPriors.BETA, + ) + reward_type: RewardLikelihood = Field( + description="The type of reward we observe from the experiment.", + default=RewardLikelihood.BERNOULLI, + ) + + +class Experiment(ExperimentBase): + """ + Pydantic model for an experiment. + """ + + # Relationships + arms: list[Arm] + notifications: Notifications + contexts: Optional[list[Context]] + clients: Optional[list[Client]] + + @model_validator(mode="after") + def auto_fail_unit_and_value_set(self) -> Self: + """ + Validate that the auto fail unit and value are set if auto fail is True. + """ + if self.auto_fail: + if ( + not self.auto_fail_value + or not self.auto_fail_unit + or self.auto_fail_value <= 0 + ): + raise ValueError( + ( + "Auto fail is enabled. " + "Please provide both auto_fail_value and auto_fail_unit." 
+                )
+            )
+        return self
+
+    @model_validator(mode="after")
+    def check_num_arms(self) -> Self:
+        """
+        Validate that the experiment has at least two arms.
+        """
+        if len(self.arms) < 2:
+            raise ValueError("The experiment must have at least two arms.")
+        if self.exp_type == ExperimentsEnum.BAYESAB and len(self.arms) > 2:
+            raise ValueError("Bayes AB experiments can only have two arms.")
+        return self
+
+    @model_validator(mode="after")
+    def check_arm_missing_params(self) -> Self:
+        """
+        Check that each arm carries the prior parameters required by the prior type.
+        """
+        prior_type = self.prior_type
+        arms = self.arms
+
+        prior_params = {
+            ArmPriors.BETA: ("alpha_init", "beta_init"),
+            ArmPriors.NORMAL: ("mu_init", "sigma_init"),
+        }
+
+        for arm in arms:
+            arm_dict = arm.model_dump()
+            if prior_type in prior_params:
+                missing_params = []
+                for param in prior_params[prior_type]:
+                    if param not in arm_dict.keys():
+                        missing_params.append(param)
+                    elif arm_dict[param] is None:
+                        missing_params.append(param)
+
+                if missing_params:
+                    val = prior_type.value
+                    raise ValueError(f"{val} prior needs {','.join(missing_params)}.")
+        return self
+
+    @model_validator(mode="after")
+    def check_treatment_info(self) -> Self:
+        """
+        Validate that the treatment arm information is set correctly.
+        """
+        arms = self.arms
+        if self.exp_type == ExperimentsEnum.BAYESAB:
+            if not any(arm.is_treatment_arm for arm in arms):
+                raise ValueError("At least one arm must be a treatment arm.")
+            if all(arm.is_treatment_arm for arm in arms):
+                raise ValueError("At least one arm must be a control arm.")
+        return self
+
+    @model_validator(mode="after")
+    def check_prior_reward_type_combo(self) -> Self:
+        """
+        Validate that the prior and reward type combination is allowed.
+        """
+        if self.prior_type == ArmPriors.BETA:
+            if not self.reward_type == RewardLikelihood.BERNOULLI:
+                raise ValueError(
+                    "Beta prior can only be used with binary-valued rewards."
+ ) + if self.exp_type != ExperimentsEnum.MAB: + raise ValueError( + f"Experiments of type {self.exp_type} can only use Gaussian priors." + ) + + return self + + @model_validator(mode="after") + def check_contexts(self) -> Self: + """ + Validate that the contexts inputs are valid. + """ + if self.exp_type == "cmab" and not self.contexts: + raise ValueError("Contextual MAB experiments require at least one context.") + if self.exp_type != "cmab" and self.contexts: + raise ValueError( + "Contexts are only applicable for contextual MAB experiments." + ) + return self + + model_config = ConfigDict(from_attributes=True) + + +class ExperimentResponse(ExperimentBase): + """ + Pydantic model for a response for experiment creation + """ + + experiment_id: int + n_trials: int + last_trial_datetime_utc: Optional[str] = None + + arms: list[ArmResponse] + notifications: list[NotificationsResponse] + contexts: Optional[list[ContextResponse]] = None + clients: Optional[list[Client]] = None + + model_config = ConfigDict(from_attributes=True) + + +class ExperimentSample(ExperimentBase): + """ + Pydantic model for experiments for drawing and updating arms. 
+ """ + + experiment_id: int + n_trials: int + last_trial_datetime_utc: Optional[str] = None + observation_type: ObservationType = ObservationType.USER + + arms: list[ArmResponse] + contexts: Optional[list[ContextResponse]] = None + clients: Optional[list[Client]] = None + + model_config = ConfigDict(from_attributes=True) diff --git a/backend/app/mab/__init__.py b/backend/app/mab/__init__.py deleted file mode 100644 index fa07d07..0000000 --- a/backend/app/mab/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .routers import router # noqa: F401 diff --git a/backend/app/mab/models.py b/backend/app/mab/models.py deleted file mode 100644 index f14d483..0000000 --- a/backend/app/mab/models.py +++ /dev/null @@ -1,419 +0,0 @@ -from datetime import datetime, timezone -from typing import Sequence - -from sqlalchemy import ( - Float, - ForeignKey, - and_, - delete, - select, -) -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import Mapped, mapped_column, relationship - -from ..models import ( - ArmBaseDB, - DrawsBaseDB, - ExperimentBaseDB, - NotificationsDB, -) -from ..schemas import ObservationType -from .schemas import MultiArmedBandit - - -class MultiArmedBanditDB(ExperimentBaseDB): - """ - ORM for managing experiments. - """ - - __tablename__ = "mabs" - - experiment_id: Mapped[int] = mapped_column( - ForeignKey("experiments_base.experiment_id", ondelete="CASCADE"), - primary_key=True, - nullable=False, - ) - arms: Mapped[list["MABArmDB"]] = relationship( - "MABArmDB", back_populates="experiment", lazy="joined" - ) - - draws: Mapped[list["MABDrawDB"]] = relationship( - "MABDrawDB", back_populates="experiment", lazy="joined" - ) - - __mapper_args__ = {"polymorphic_identity": "mabs"} - - def to_dict(self) -> dict: - """ - Convert the ORM object to a dictionary. 
- """ - return { - "experiment_id": self.experiment_id, - "user_id": self.user_id, - "workspace_id": self.workspace_id, - "name": self.name, - "description": self.description, - "sticky_assignment": self.sticky_assignment, - "auto_fail": self.auto_fail, - "auto_fail_value": self.auto_fail_value, - "auto_fail_unit": self.auto_fail_unit, - "created_datetime_utc": self.created_datetime_utc, - "is_active": self.is_active, - "n_trials": self.n_trials, - "arms": [arm.to_dict() for arm in self.arms], - "prior_type": self.prior_type, - "reward_type": self.reward_type, - } - - -class MABArmDB(ArmBaseDB): - """ - ORM for managing arms of an experiment - """ - - __tablename__ = "mab_arms" - - arm_id: Mapped[int] = mapped_column( - ForeignKey("arms_base.arm_id", ondelete="CASCADE"), - primary_key=True, - nullable=False, - ) - - # prior variables for MAB arms - alpha: Mapped[float] = mapped_column(Float, nullable=True) - beta: Mapped[float] = mapped_column(Float, nullable=True) - mu: Mapped[float] = mapped_column(Float, nullable=True) - sigma: Mapped[float] = mapped_column(Float, nullable=True) - alpha_init: Mapped[float] = mapped_column(Float, nullable=True) - beta_init: Mapped[float] = mapped_column(Float, nullable=True) - mu_init: Mapped[float] = mapped_column(Float, nullable=True) - sigma_init: Mapped[float] = mapped_column(Float, nullable=True) - experiment: Mapped[MultiArmedBanditDB] = relationship( - "MultiArmedBanditDB", back_populates="arms", lazy="joined" - ) - - draws: Mapped[list["MABDrawDB"]] = relationship( - "MABDrawDB", back_populates="arm", lazy="joined" - ) - - __mapper_args__ = {"polymorphic_identity": "mab_arms"} - - def to_dict(self) -> dict: - """ - Convert the ORM object to a dictionary. 
- """ - return { - "arm_id": self.arm_id, - "name": self.name, - "description": self.description, - "alpha": self.alpha, - "beta": self.beta, - "mu": self.mu, - "sigma": self.sigma, - "alpha_init": self.alpha_init, - "beta_init": self.beta_init, - "mu_init": self.mu_init, - "sigma_init": self.sigma_init, - "draws": [draw.to_dict() for draw in self.draws], - } - - -class MABDrawDB(DrawsBaseDB): - """ - ORM for managing draws of an experiment - """ - - __tablename__ = "mab_draws" - - draw_id: Mapped[str] = mapped_column( - ForeignKey("draws_base.draw_id", ondelete="CASCADE"), - primary_key=True, - nullable=False, - ) - - arm: Mapped[MABArmDB] = relationship( - "MABArmDB", back_populates="draws", lazy="joined" - ) - experiment: Mapped[MultiArmedBanditDB] = relationship( - "MultiArmedBanditDB", back_populates="draws", lazy="joined" - ) - - __mapper_args__ = {"polymorphic_identity": "mab_draws"} - - def to_dict(self) -> dict: - """ - Convert the ORM object to a dictionary. - """ - return { - "draw_id": self.draw_id, - "client_id": self.client_id, - "draw_datetime_utc": self.draw_datetime_utc, - "arm_id": self.arm_id, - "experiment_id": self.experiment_id, - "user_id": self.user_id, - "reward": self.reward, - "observation_type": self.observation_type, - "observed_datetime_utc": self.observed_datetime_utc, - } - - -async def save_mab_to_db( - experiment: MultiArmedBandit, - user_id: int, - workspace_id: int, - asession: AsyncSession, -) -> MultiArmedBanditDB: - """ - Save the experiment to the database. 
- """ - arms = [ - MABArmDB( - name=arm.name, - description=arm.description, - alpha_init=arm.alpha_init, - beta_init=arm.beta_init, - mu_init=arm.mu_init, - sigma_init=arm.sigma_init, - n_outcomes=arm.n_outcomes, - alpha=arm.alpha_init, - beta=arm.beta_init, - mu=arm.mu_init, - sigma=arm.sigma_init, - user_id=user_id, - ) - for arm in experiment.arms - ] - experiment_db = MultiArmedBanditDB( - name=experiment.name, - description=experiment.description, - user_id=user_id, - workspace_id=workspace_id, - is_active=experiment.is_active, - created_datetime_utc=datetime.now(timezone.utc), - n_trials=0, - arms=arms, - sticky_assignment=experiment.sticky_assignment, - auto_fail=experiment.auto_fail, - auto_fail_value=experiment.auto_fail_value, - auto_fail_unit=experiment.auto_fail_unit, - prior_type=experiment.prior_type.value, - reward_type=experiment.reward_type.value, - ) - - asession.add(experiment_db) - await asession.commit() - await asession.refresh(experiment_db) - - return experiment_db - - -async def get_all_mabs( - workspace_id: int, - asession: AsyncSession, -) -> Sequence[MultiArmedBanditDB]: - """ - Get all the experiments from the database for a specific workspace. - """ - statement = ( - select(MultiArmedBanditDB) - .where( - MultiArmedBanditDB.workspace_id == workspace_id, - ) - .order_by(MultiArmedBanditDB.experiment_id) - ) - - return (await asession.execute(statement)).unique().scalars().all() - - -async def get_mab_by_id( - experiment_id: int, - workspace_id: int, - asession: AsyncSession, -) -> MultiArmedBanditDB | None: - """ - Get the experiment by id from a specific workspace. 
- """ - conditions = [ - MultiArmedBanditDB.workspace_id == workspace_id, - MultiArmedBanditDB.experiment_id == experiment_id, - ] - - result = await asession.execute(select(MultiArmedBanditDB).where(and_(*conditions))) - - return result.unique().scalar_one_or_none() - - -async def delete_mab_by_id( - experiment_id: int, workspace_id: int, asession: AsyncSession -) -> None: - """ - Delete the experiment by id. - """ - await asession.execute( - delete(NotificationsDB).where(NotificationsDB.experiment_id == experiment_id) - ) - - await asession.execute( - delete(DrawsBaseDB).where(DrawsBaseDB.experiment_id == experiment_id) - ) - - await asession.execute( - delete(MABArmDB).where( - and_( - MABArmDB.arm_id == ArmBaseDB.arm_id, - ArmBaseDB.experiment_id == experiment_id, - ) - ) - ) - await asession.execute( - delete(MultiArmedBanditDB).where( - and_( - MultiArmedBanditDB.experiment_id == experiment_id, - MultiArmedBanditDB.experiment_id == ExperimentBaseDB.experiment_id, - MultiArmedBanditDB.workspace_id == workspace_id, - ) - ) - ) - await asession.commit() - return None - - -async def get_obs_by_experiment_arm_id( - experiment_id: int, arm_id: int, asession: AsyncSession -) -> Sequence[MABDrawDB]: - """ - Get the observations for the experiment and arm. - """ - statement = ( - select(MABDrawDB) - .where(MABDrawDB.experiment_id == experiment_id) - .where(MABDrawDB.reward.is_not(None)) - .where(MABDrawDB.arm_id == arm_id) - .order_by(MABDrawDB.observed_datetime_utc) - ) - - return (await asession.execute(statement)).unique().scalars().all() - - -async def get_all_obs_by_experiment_id( - experiment_id: int, - workspace_id: int, - asession: AsyncSession, -) -> Sequence[MABDrawDB]: - """ - Get the observations for the experiment. 
- """ - # First, verify experiment belongs to the workspace - experiment = await get_mab_by_id( - experiment_id=experiment_id, - workspace_id=workspace_id, - asession=asession, - ) - - if experiment is None: - # Return empty list if experiment doesn't exist or doesn't belong to workspace - return [] - - statement = ( - select(MABDrawDB) - .where(MABDrawDB.experiment_id == experiment_id) - .where(MABDrawDB.reward.is_not(None)) - .order_by(MABDrawDB.observed_datetime_utc) - ) - - return (await asession.execute(statement)).unique().scalars().all() - - -async def get_draw_by_id(draw_id: str, asession: AsyncSession) -> MABDrawDB | None: - """ - Get a draw by its ID, which should be unique across the system. - """ - statement = select(MABDrawDB).where(MABDrawDB.draw_id == draw_id) - result = await asession.execute(statement) - - return result.unique().scalar_one_or_none() - - -async def get_draw_by_client_id( - client_id: str, - experiment_id: int, - asession: AsyncSession, -) -> MABDrawDB | None: - """ - Get a draw by its client ID for a specific experiment. 
- """ - statement = ( - select(MABDrawDB) - .where(MABDrawDB.client_id == client_id) - .where(MABDrawDB.client_id.is_not(None)) - .where(MABDrawDB.experiment_id == experiment_id) - ) - result = await asession.execute(statement) - - return result.unique().scalars().first() - - -async def save_draw_to_db( - experiment_id: int, - arm_id: int, - draw_id: str, - client_id: str | None, - user_id: int | None, - asession: AsyncSession, - workspace_id: int | None = None, -) -> MABDrawDB: - """ - Save a draw to the database - """ - # If user_id is not provided but needed, get it from the experiment - if user_id is None and workspace_id is not None: - experiment = await get_mab_by_id( - experiment_id=experiment_id, - workspace_id=workspace_id, - asession=asession, - ) - - if experiment: - user_id = experiment.user_id - else: - raise ValueError(f"Experiment with id {experiment_id} not found") - - if user_id is None: - raise ValueError("User ID must be provided or derivable from experiment") - - draw_datetime_utc: datetime = datetime.now(timezone.utc) - - draw = MABDrawDB( - draw_id=draw_id, - client_id=client_id, - experiment_id=experiment_id, - user_id=user_id, - arm_id=arm_id, - draw_datetime_utc=draw_datetime_utc, - ) - - asession.add(draw) - await asession.commit() - await asession.refresh(draw) - - return draw - - -async def save_observation_to_db( - draw: MABDrawDB, - reward: float, - asession: AsyncSession, - observation_type: ObservationType, -) -> MABDrawDB: - """ - Save an observation to the database - """ - - draw.reward = reward - draw.observed_datetime_utc = datetime.now(timezone.utc) - draw.observation_type = observation_type - asession.add(draw) - await asession.commit() - await asession.refresh(draw) - - return draw diff --git a/backend/app/mab/observation.py b/backend/app/mab/observation.py deleted file mode 100644 index 0ef34c2..0000000 --- a/backend/app/mab/observation.py +++ /dev/null @@ -1,94 +0,0 @@ -from datetime import datetime, timezone - -from fastapi 
import HTTPException -from sqlalchemy.ext.asyncio import AsyncSession - -from ..schemas import ObservationType, Outcome, RewardLikelihood -from .models import ( - MABArmDB, - MABDrawDB, - MultiArmedBanditDB, - save_observation_to_db, -) -from .sampling_utils import update_arm_params -from .schemas import ( - ArmResponse, - MultiArmedBanditSample, -) - - -async def update_based_on_outcome( - experiment: MultiArmedBanditDB, - draw: MABDrawDB, - outcome: float, - asession: AsyncSession, - observation_type: ObservationType, -) -> ArmResponse: - """ - Update the arm parameters based on the outcome. - - This is a helper function to allow `auto_fail` job to call - it as well. - """ - update_experiment_metadata(experiment) - - arm = get_arm_from_experiment(experiment, draw.arm_id) - arm.n_outcomes += 1 - - experiment_data = MultiArmedBanditSample.model_validate(experiment) - await update_arm_parameters(arm, experiment_data, outcome) - await save_updated_data(arm, draw, outcome, observation_type, asession) - - return ArmResponse.model_validate(arm) - - -def update_experiment_metadata(experiment: MultiArmedBanditDB) -> None: - """Update experiment metadata with new trial information""" - experiment.n_trials += 1 - experiment.last_trial_datetime_utc = datetime.now(tz=timezone.utc) - - -def get_arm_from_experiment(experiment: MultiArmedBanditDB, arm_id: int) -> MABArmDB: - """Get and validate the arm from the experiment""" - arms = [a for a in experiment.arms if a.arm_id == arm_id] - if not arms: - raise HTTPException(status_code=404, detail=f"Arm with id {arm_id} not found") - return arms[0] - - -async def update_arm_parameters( - arm: MABArmDB, experiment_data: MultiArmedBanditSample, outcome: float -) -> None: - """Update the arm parameters based on the reward type and outcome""" - if experiment_data.reward_type == RewardLikelihood.BERNOULLI: - Outcome(outcome) # Check if reward is 0 or 1 - arm.alpha, arm.beta = update_arm_params( - ArmResponse.model_validate(arm), - 
experiment_data.prior_type, - experiment_data.reward_type, - outcome, - ) - elif experiment_data.reward_type == RewardLikelihood.NORMAL: - arm.mu, arm.sigma = update_arm_params( - ArmResponse.model_validate(arm), - experiment_data.prior_type, - experiment_data.reward_type, - outcome, - ) - else: - raise HTTPException( - status_code=400, - detail="Reward type not supported.", - ) - - -async def save_updated_data( - arm: MABArmDB, - draw: MABDrawDB, - outcome: float, - observation_type: ObservationType, - asession: AsyncSession, -) -> None: - """Save the updated arm and observation data""" - await asession.commit() - await save_observation_to_db(draw, outcome, asession, observation_type) diff --git a/backend/app/mab/routers.py b/backend/app/mab/routers.py deleted file mode 100644 index 9a22582..0000000 --- a/backend/app/mab/routers.py +++ /dev/null @@ -1,357 +0,0 @@ -from typing import Annotated, Optional -from uuid import uuid4 - -from fastapi import APIRouter, Depends -from fastapi.exceptions import HTTPException -from sqlalchemy.ext.asyncio import AsyncSession - -from ..auth.dependencies import ( - authenticate_workspace_key, - get_verified_user, - require_admin_role, -) -from ..database import get_async_session -from ..models import get_notifications_from_db, save_notifications_to_db -from ..schemas import NotificationsResponse, ObservationType -from ..users.models import UserDB -from ..utils import setup_logger -from ..workspaces.models import ( - WorkspaceDB, - get_user_default_workspace, -) -from .models import ( - MABDrawDB, - MultiArmedBanditDB, - delete_mab_by_id, - get_all_mabs, - get_all_obs_by_experiment_id, - get_draw_by_client_id, - get_draw_by_id, - get_mab_by_id, - save_draw_to_db, - save_mab_to_db, -) -from .observation import update_based_on_outcome -from .sampling_utils import choose_arm -from .schemas import ( - ArmResponse, - MABDrawResponse, - MABObservationResponse, - MultiArmedBandit, - MultiArmedBanditResponse, - MultiArmedBanditSample, -) - 
-router = APIRouter(prefix="/mab", tags=["Multi-Armed Bandits"]) - -logger = setup_logger(__name__) - - -@router.post("/", response_model=MultiArmedBanditResponse) -async def create_mab( - experiment: MultiArmedBandit, - user_db: Annotated[UserDB, Depends(require_admin_role)], - asession: AsyncSession = Depends(get_async_session), -) -> MultiArmedBanditResponse: - """ - Create a new experiment in the user's current workspace. - """ - workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) - - if workspace_db is None: - raise HTTPException( - status_code=404, - detail="Workspace not found. Please create a workspace first.", - ) - - mab = await save_mab_to_db( - experiment, user_db.user_id, workspace_db.workspace_id, asession - ) - - notifications = await save_notifications_to_db( - experiment_id=mab.experiment_id, - user_id=user_db.user_id, - notifications=experiment.notifications, - asession=asession, - ) - - mab_dict = mab.to_dict() - mab_dict["notifications"] = [n.to_dict() for n in notifications] - - return MultiArmedBanditResponse.model_validate(mab_dict) - - -@router.get("/", response_model=list[MultiArmedBanditResponse]) -async def get_mabs( - user_db: Annotated[UserDB, Depends(get_verified_user)], - asession: AsyncSession = Depends(get_async_session), -) -> list[MultiArmedBanditResponse]: - """ - Get details of all experiments in the user's current workspace. - """ - workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) - - if workspace_db is None: - raise HTTPException( - status_code=404, - detail="Workspace not found. 
Please create a workspace first.", - ) - - experiments = await get_all_mabs(workspace_db.workspace_id, asession) - - all_experiments = [] - for exp in experiments: - exp_dict = exp.to_dict() - exp_dict["notifications"] = [ - n.to_dict() - for n in await get_notifications_from_db( - exp.experiment_id, exp.user_id, asession - ) - ] - all_experiments.append( - MultiArmedBanditResponse.model_validate( - { - **exp_dict, - "notifications": [ - NotificationsResponse(**n) for n in exp_dict["notifications"] - ], - } - ) - ) - return all_experiments - - -@router.get("/{experiment_id}/", response_model=MultiArmedBanditResponse) -async def get_mab( - experiment_id: int, - user_db: Annotated[UserDB, Depends(get_verified_user)], - asession: AsyncSession = Depends(get_async_session), -) -> MultiArmedBanditResponse: - """ - Get details of experiment with the provided `experiment_id`. - """ - workspace_db = await get_user_default_workspace(asession=asession, user_db=user_db) - - if workspace_db is None: - raise HTTPException( - status_code=404, - detail="Workspace not found. Please create a workspace first.", - ) - - experiment = await get_mab_by_id(experiment_id, workspace_db.workspace_id, asession) - - if experiment is None: - raise HTTPException( - status_code=404, detail=f"Experiment with id {experiment_id} not found" - ) - - experiment_dict = experiment.to_dict() - experiment_dict["notifications"] = [ - n.to_dict() - for n in await get_notifications_from_db( - experiment.experiment_id, experiment.user_id, asession - ) - ] - - return MultiArmedBanditResponse.model_validate(experiment_dict) - - -@router.delete("/{experiment_id}", response_model=dict) -async def delete_mab( - experiment_id: int, - user_db: Annotated[UserDB, Depends(require_admin_role)], - asession: AsyncSession = Depends(get_async_session), -) -> dict: - """ - Delete the experiment with the provided `experiment_id`. 
- """ - try: - workspace_db = await get_user_default_workspace( - asession=asession, user_db=user_db - ) - - if workspace_db is None: - raise HTTPException( - status_code=404, - detail="Workspace not found. Please create a workspace first.", - ) - - experiment = await get_mab_by_id( - experiment_id, workspace_db.workspace_id, asession - ) - if experiment is None: - raise HTTPException( - status_code=404, detail=f"Experiment with id {experiment_id} not found" - ) - await delete_mab_by_id(experiment_id, workspace_db.workspace_id, asession) - return {"message": f"Experiment with id {experiment_id} deleted successfully."} - except Exception as e: - raise HTTPException(status_code=500, detail=f"Error: {e}") from e - - -@router.get("/{experiment_id}/draw", response_model=MABDrawResponse) -async def draw_arm( - experiment_id: int, - draw_id: Optional[str] = None, - client_id: Optional[str] = None, - workspace_db: WorkspaceDB = Depends(authenticate_workspace_key), - asession: AsyncSession = Depends(get_async_session), -) -> MABDrawResponse: - """ - Draw an arm for the provided experiment. 
- """ - # Get workspace from user context - workspace_id = workspace_db.workspace_id - - experiment = await get_mab_by_id(experiment_id, workspace_id, asession) - if experiment is None: - raise HTTPException( - status_code=404, detail=f"Experiment with id {experiment_id} not found" - ) - - if experiment.sticky_assignment and client_id is None: - raise HTTPException( - status_code=400, - detail="Client ID is required for sticky assignment.", - ) - - # Check for existing draws - if draw_id is None: - draw_id = str(uuid4()) - - existing_draw = await get_draw_by_id(draw_id, asession) - if existing_draw: - raise HTTPException( - status_code=400, - detail=f"Draw ID {draw_id} already exists.", - ) - - experiment_data = MultiArmedBanditSample.model_validate(experiment) - chosen_arm = choose_arm(experiment=experiment_data) - chosen_arm_id = experiment.arms[chosen_arm].arm_id - - # If sticky assignment, check if the client_id has a previous arm assigned - if experiment.sticky_assignment and client_id: - previous_draw = await get_draw_by_client_id( - client_id=client_id, - experiment_id=experiment.experiment_id, - asession=asession, - ) - if previous_draw: - print(f"Previous draw found: {previous_draw.arm_id}") - chosen_arm_id = previous_draw.arm_id - - try: - _ = await save_draw_to_db( - experiment_id=experiment.experiment_id, - arm_id=chosen_arm_id, - draw_id=draw_id, - client_id=client_id, - user_id=None, - asession=asession, - workspace_id=workspace_id, - ) - except Exception as e: - raise HTTPException( - status_code=500, - detail=f"Error saving draw to database: {e}", - ) from e - - return MABDrawResponse.model_validate( - { - "draw_id": draw_id, - "client_id": client_id, - "arm": ArmResponse.model_validate( - [arm for arm in experiment.arms if arm.arm_id == chosen_arm_id][0] - ), - } - ) - - -@router.put("/{experiment_id}/{draw_id}/{outcome}", response_model=ArmResponse) -async def update_arm( - experiment_id: int, - draw_id: str, - outcome: float, - workspace_db: 
WorkspaceDB = Depends(authenticate_workspace_key), - asession: AsyncSession = Depends(get_async_session), -) -> ArmResponse: - """ - Update the arm with the provided `arm_id` for the given - `experiment_id` based on the `outcome`. - """ - # Get workspace from user context - workspace_id = workspace_db.workspace_id - - experiment, draw = await validate_experiment_and_draw( - experiment_id, draw_id, workspace_id, asession - ) - - return await update_based_on_outcome( - experiment, draw, outcome, asession, ObservationType.USER - ) - - -@router.get( - "/{experiment_id}/outcomes", - response_model=list[MABObservationResponse], -) -async def get_outcomes( - experiment_id: int, - workspace_db: WorkspaceDB = Depends(authenticate_workspace_key), - asession: AsyncSession = Depends(get_async_session), -) -> list[MABObservationResponse]: - """ - Get the outcomes for the experiment. - """ - # Get workspace from user context - workspace_id = workspace_db.workspace_id - - experiment = await get_mab_by_id(experiment_id, workspace_id, asession) - if not experiment: - raise HTTPException( - status_code=404, detail=f"Experiment with id {experiment_id} not found" - ) - - rewards = await get_all_obs_by_experiment_id( - experiment_id=experiment.experiment_id, - workspace_id=workspace_id, - asession=asession, - ) - - return [MABObservationResponse.model_validate(reward) for reward in rewards] - - -async def validate_experiment_and_draw( - experiment_id: int, - draw_id: str, - workspace_id: int, - asession: AsyncSession, -) -> tuple[MultiArmedBanditDB, MABDrawDB]: - """Validate the experiment and draw""" - experiment = await get_mab_by_id(experiment_id, workspace_id, asession) - if experiment is None: - raise HTTPException( - status_code=404, detail=f"Experiment with id {experiment_id} not found" - ) - - draw = await get_draw_by_id(draw_id=draw_id, asession=asession) - if draw is None: - raise HTTPException(status_code=404, detail=f"Draw with id {draw_id} not found") - - if 
draw.experiment_id != experiment_id: - raise HTTPException( - status_code=400, - detail=( - f"Draw with id {draw_id} does not belong " - f"to experiment with id {experiment_id}", - ), - ) - - if draw.reward is not None: - raise HTTPException( - status_code=400, - detail=f"Draw with id {draw_id} already has an outcome.", - ) - - return experiment, draw diff --git a/backend/app/mab/sampling_utils.py b/backend/app/mab/sampling_utils.py deleted file mode 100644 index 6bc5fe8..0000000 --- a/backend/app/mab/sampling_utils.py +++ /dev/null @@ -1,138 +0,0 @@ -import numpy as np -from numpy.random import beta, normal - -from ..mab.schemas import ArmResponse, MultiArmedBanditSample -from ..schemas import ArmPriors, Outcome, RewardLikelihood - - -def sample_beta_binomial(alphas: np.ndarray, betas: np.ndarray) -> int: - """ - Thompson Sampling with Beta-Binomial distribution. - - Parameters - ---------- - alphas : alpha parameter of Beta distribution for each arm - betas : beta parameter of Beta distribution for each arm - """ - samples = beta(alphas, betas) - return int(samples.argmax()) - - -def sample_normal(mus: np.ndarray, sigmas: np.ndarray) -> int: - """ - Thompson Sampling with conjugate normal distribution. - - Parameters - ---------- - mus: mean of Normal distribution for each arm - sigmas: standard deviation of Normal distribution for each arm - """ - samples = normal(loc=mus, scale=sigmas) - return int(samples.argmax()) - - -def update_arm_beta_binomial( - alpha: float, beta: float, reward: Outcome -) -> tuple[float, float]: - """ - Update the alpha and beta parameters of the Beta distribution. - - Parameters - ---------- - alpha : int - The alpha parameter of the Beta distribution. - beta : int - The beta parameter of the Beta distribution. - reward : Outcome - The reward of the arm. 
- """ - if reward == Outcome.SUCCESS: - - return alpha + 1, beta - else: - return alpha, beta + 1 - - -def update_arm_normal( - current_mu: float, current_sigma: float, reward: float, sigma_llhood: float -) -> tuple[float, float]: - """ - Update the mean and standard deviation of the Normal distribution. - - Parameters - ---------- - current_mu : The mean of the Normal distribution. - current_sigma : The standard deviation of the Normal distribution. - reward : The reward of the arm. - sigma_llhood : The likelihood of the standard deviation. - """ - denom = sigma_llhood**2 + current_sigma**2 - new_sigma = sigma_llhood * current_sigma / np.sqrt(denom) - new_mu = (current_mu * sigma_llhood**2 + reward * current_sigma**2) / denom - return new_mu, new_sigma - - -def choose_arm(experiment: MultiArmedBanditSample) -> int: - """ - Choose arm based on posterior - - Parameters - ---------- - experiment : MultiArmedBanditResponse - The experiment data containing priors and rewards for each arm. - """ - if (experiment.prior_type == ArmPriors.BETA) and ( - experiment.reward_type == RewardLikelihood.BERNOULLI - ): - alphas = np.array([arm.alpha for arm in experiment.arms]) - betas = np.array([arm.beta for arm in experiment.arms]) - - return sample_beta_binomial(alphas=alphas, betas=betas) - - elif (experiment.prior_type == ArmPriors.NORMAL) and ( - experiment.reward_type == RewardLikelihood.NORMAL - ): - mus = np.array([arm.mu for arm in experiment.arms]) - sigmas = np.array([arm.sigma for arm in experiment.arms]) - # TODO: add support for non-std sigma_llhood - return sample_normal(mus=mus, sigmas=sigmas) - else: - raise ValueError("Prior and reward type combination is not supported.") - - -def update_arm_params( - arm: ArmResponse, - prior_type: ArmPriors, - reward_type: RewardLikelihood, - reward: float, -) -> tuple: - """ - Update the arm with the provided `arm_id` based on the `reward`. - - Parameters - ---------- - arm: The arm to update. 
- prior_type: The type of prior distribution for the arms. - reward_type: The likelihood distribution of the reward. - reward: The reward of the arm. - """ - - if (prior_type == ArmPriors.BETA) and (reward_type == RewardLikelihood.BERNOULLI): - if arm.alpha is None or arm.beta is None: - raise ValueError("Beta prior requires alpha and beta.") - outcome = Outcome(reward) - return update_arm_beta_binomial(alpha=arm.alpha, beta=arm.beta, reward=outcome) - - elif ( - (prior_type == ArmPriors.NORMAL) - and (reward_type == RewardLikelihood.NORMAL) - and (arm.mu and arm.sigma) - ): - return update_arm_normal( - current_mu=arm.mu, - current_sigma=arm.sigma, - reward=reward, - sigma_llhood=1.0, # TODO: add support for non-std sigma_llhood - ) - else: - raise ValueError("Prior and reward type combination is not supported.") diff --git a/backend/app/mab/schemas.py b/backend/app/mab/schemas.py deleted file mode 100644 index 60fbf0e..0000000 --- a/backend/app/mab/schemas.py +++ /dev/null @@ -1,262 +0,0 @@ -from datetime import datetime -from typing import Optional, Self - -from pydantic import BaseModel, ConfigDict, Field, model_validator - -from ..schemas import ( - ArmPriors, - AutoFailUnitType, - Notifications, - NotificationsResponse, - RewardLikelihood, - allowed_combos_mab, -) - - -class Arm(BaseModel): - """ - Pydantic model for a arm of the experiment. 
- """ - - name: str = Field( - max_length=150, - examples=["Arm 1"], - ) - description: str = Field( - max_length=500, - examples=["This is a description of the arm."], - ) - - # prior variables - alpha_init: Optional[float] = Field( - default=None, examples=[None, 1.0], description="Alpha parameter for Beta prior" - ) - beta_init: Optional[float] = Field( - default=None, examples=[None, 1.0], description="Beta parameter for Beta prior" - ) - mu_init: Optional[float] = Field( - default=None, - examples=[None, 0.0], - description="Mean parameter for Normal prior", - ) - sigma_init: Optional[float] = Field( - default=None, - examples=[None, 1.0], - description="Standard deviation parameter for Normal prior", - ) - n_outcomes: Optional[int] = Field( - default=0, - description="Number of outcomes for the arm", - examples=[0, 10, 15], - ) - - @model_validator(mode="after") - def check_values(self) -> Self: - """ - Check if the values are unique. - """ - alpha = self.alpha_init - beta = self.beta_init - sigma = self.sigma_init - if alpha is not None and alpha <= 0: - raise ValueError("Alpha must be greater than 0.") - if beta is not None and beta <= 0: - raise ValueError("Beta must be greater than 0.") - if sigma is not None and sigma <= 0: - raise ValueError("Sigma must be greater than 0.") - return self - - -class ArmResponse(Arm): - """ - Pydantic model for an response for arm creation - """ - - arm_id: int - alpha: Optional[float] - beta: Optional[float] - mu: Optional[float] - sigma: Optional[float] - model_config = ConfigDict( - from_attributes=True, - ) - - -class MultiArmedBanditBase(BaseModel): - """ - Pydantic model for an experiment - Base model. - Note: Do not use this model directly. Use `MultiArmedBandit` instead. 
- """ - - name: str = Field( - max_length=150, - examples=["Experiment 1"], - ) - - description: str = Field( - max_length=500, - examples=["This is a description of the experiment."], - ) - - sticky_assignment: bool = Field( - description="Whether the arm assignment is sticky or not.", - default=False, - ) - - auto_fail: bool = Field( - description=( - "Whether the experiment should fail automatically after " - "a certain period if no outcome is registered." - ), - default=False, - ) - - auto_fail_value: Optional[int] = Field( - description="The time period after which the experiment should fail.", - default=None, - ) - - auto_fail_unit: Optional[AutoFailUnitType] = Field( - description="The time unit for the auto fail period.", - default=None, - ) - - reward_type: RewardLikelihood = Field( - description="The type of reward we observe from the experiment.", - default=RewardLikelihood.BERNOULLI, - ) - prior_type: ArmPriors = Field( - description="The type of prior distribution for the arms.", - default=ArmPriors.BETA, - ) - - is_active: bool = True - - model_config = ConfigDict(from_attributes=True) - - -class MultiArmedBandit(MultiArmedBanditBase): - """ - Pydantic model for an experiment. - """ - - arms: list[Arm] - notifications: Notifications - - @model_validator(mode="after") - def auto_fail_unit_and_value_set(self) -> Self: - """ - Validate that the auto fail unit and value are set if auto fail is True. - """ - if self.auto_fail: - if ( - not self.auto_fail_value - or not self.auto_fail_unit - or self.auto_fail_value <= 0 - ): - raise ValueError( - ( - "Auto fail is enabled. " - "Please provide both auto_fail_value and auto_fail_unit." - ) - ) - return self - - @model_validator(mode="after") - def arms_at_least_two(self) -> Self: - """ - Validate that the experiment has at least two arms. 
- """ - if len(self.arms) < 2: - raise ValueError("The experiment must have at least two arms.") - return self - - @model_validator(mode="after") - def check_prior_reward_type_combo(self) -> Self: - """ - Validate that the prior and reward type combination is allowed. - """ - - if (self.prior_type, self.reward_type) not in allowed_combos_mab: - raise ValueError("Prior and reward type combo not supported.") - return self - - @model_validator(mode="after") - def check_arm_missing_params(self) -> Self: - """ - Check if the arm reward type is same as the experiment reward type. - """ - prior_type = self.prior_type - arms = self.arms - - prior_params = { - ArmPriors.BETA: ("alpha_init", "beta_init"), - ArmPriors.NORMAL: ("mu_init", "sigma_init"), - } - - for arm in arms: - arm_dict = arm.model_dump() - if prior_type in prior_params: - missing_params = [] - for param in prior_params[prior_type]: - if param not in arm_dict.keys(): - missing_params.append(param) - elif arm_dict[param] is None: - missing_params.append(param) - - if missing_params: - val = prior_type.value - raise ValueError(f"{val} prior needs {','.join(missing_params)}.") - return self - - model_config = ConfigDict(from_attributes=True) - - -class MultiArmedBanditResponse(MultiArmedBanditBase): - """ - Pydantic model for an response for experiment creation. - Returns the id of the experiment and the arms - """ - - experiment_id: int - workspace_id: int - arms: list[ArmResponse] - notifications: list[NotificationsResponse] - created_datetime_utc: datetime - last_trial_datetime_utc: Optional[datetime] = None - n_trials: int - model_config = ConfigDict(from_attributes=True, revalidate_instances="always") - - -class MultiArmedBanditSample(MultiArmedBanditBase): - """ - Pydantic model for an experiment sample. - """ - - experiment_id: int - arms: list[ArmResponse] - - -class MABObservationResponse(BaseModel): - """ - Pydantic model for binary observations of the experiment. 
- """ - - experiment_id: int - arm_id: int - reward: float - draw_id: str - client_id: str | None - observed_datetime_utc: datetime - - model_config = ConfigDict(from_attributes=True) - - -class MABDrawResponse(BaseModel): - """ - Pydantic model for the response of the draw endpoint. - """ - - draw_id: str - client_id: str | None - arm: ArmResponse diff --git a/backend/app/models.py b/backend/app/models.py index 097aa2b..77833b5 100644 --- a/backend/app/models.py +++ b/backend/app/models.py @@ -13,12 +13,12 @@ select, ) from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column from .schemas import AutoFailUnitType, EventType, Notifications, ObservationType if TYPE_CHECKING: - from .workspaces.models import WorkspaceDB + pass class Base(DeclarativeBase): @@ -66,9 +66,6 @@ class ExperimentBaseDB(Base): last_trial_datetime_utc: Mapped[datetime] = mapped_column( DateTime(timezone=True), nullable=True ) - workspace: Mapped["WorkspaceDB"] = relationship( - "WorkspaceDB", back_populates="experiments" - ) __mapper_args__ = { "polymorphic_identity": "experiment", @@ -161,7 +158,7 @@ class NotificationsDB(Base): the background celery job """ - __tablename__ = "notifications" + __tablename__ = "notifications_db" notification_id: Mapped[int] = mapped_column( Integer, primary_key=True, nullable=False diff --git a/backend/app/workspaces/models.py b/backend/app/workspaces/models.py index b79aeab..52146cc 100644 --- a/backend/app/workspaces/models.py +++ b/backend/app/workspaces/models.py @@ -19,7 +19,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Mapped, mapped_column, relationship -from ..models import Base, ExperimentBaseDB +from ..models import Base from ..users.exceptions import UserNotFoundError from ..users.schemas import UserCreate from .schemas import UserCreateWithCode, UserRoles @@ -77,9 +77,6 @@ class 
WorkspaceDB(Base): workspace_id: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False) workspace_name: Mapped[str] = mapped_column(String, nullable=False, unique=True) is_default: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) - experiments: Mapped[list["ExperimentBaseDB"]] = relationship( - "ExperimentBaseDB", back_populates="workspace", cascade="all, delete-orphan" - ) pending_invitations: Mapped[list["PendingInvitationDB"]] = relationship( "PendingInvitationDB", back_populates="workspace", cascade="all, delete-orphan" diff --git a/backend/jobs/auto_fail.py b/backend/jobs/auto_fail.py index 3665831..2323a56 100644 --- a/backend/jobs/auto_fail.py +++ b/backend/jobs/auto_fail.py @@ -15,21 +15,16 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from app.bayes_ab.models import BayesianABDB, BayesianABDrawDB -from app.bayes_ab.observation import ( - update_based_on_outcome as bayes_ab_update_based_on_outcome, -) -from app.contextual_mab.models import ContextualBanditDB, ContextualDrawDB -from app.contextual_mab.observation import ( - update_based_on_outcome as cmab_update_based_on_outcome, -) from app.database import get_async_session -from app.mab.models import MABDrawDB, MultiArmedBanditDB -from app.mab.observation import update_based_on_outcome as mab_update_based_on_outcome -from app.schemas import ObservationType +from app.experiments.dependencies import ( + format_rewards_for_arm_update, + update_arm_based_on_outcome, +) +from app.experiments.models import DrawDB, ExperimentDB +from app.experiments.schemas import ObservationType -async def auto_fail_mab(asession: AsyncSession) -> int: +async def auto_fail_experiment(asession: AsyncSession) -> int: """ Auto fail experiments draws that have not been updated in a certain amount of time. 
@@ -43,9 +38,7 @@ async def auto_fail_mab(asession: AsyncSession) -> int: now = datetime.now(tz=timezone.utc) # Fetch all required experiments data in one query - experiment_query = select(MultiArmedBanditDB).where( - MultiArmedBanditDB.auto_fail.is_(True) - ) + experiment_query = select(ExperimentDB).where(ExperimentDB.auto_fail.is_(True)) experiments_result = (await asession.execute(experiment_query)).unique() experiments = experiments_result.scalars().all() for experiment in experiments: @@ -58,15 +51,15 @@ async def auto_fail_mab(asession: AsyncSession) -> int: cutoff_datetime = now - timedelta(hours=hours_threshold) draws_query = ( - select(MABDrawDB) + select(DrawDB) .join( - MultiArmedBanditDB, - MABDrawDB.experiment_id == MultiArmedBanditDB.experiment_id, + ExperimentDB, + DrawDB.experiment_id == ExperimentDB.experiment_id, ) .where( - MABDrawDB.experiment_id == experiment.experiment_id, - MABDrawDB.observation_type.is_(None), - MABDrawDB.draw_datetime_utc <= cutoff_datetime, + DrawDB.experiment_id == experiment.experiment_id, + DrawDB.observation_type.is_(None), + DrawDB.draw_datetime_utc <= cutoff_datetime, ) .limit(100) ) # Process in smaller batches @@ -83,146 +76,17 @@ async def auto_fail_mab(asession: AsyncSession) -> int: for draw in draws_batch: draw.observation_type = ObservationType.AUTO - await mab_update_based_on_outcome( - experiment, - draw, - 0.0, - asession, - ObservationType.AUTO, - ) - - total_failed += 1 - - await asession.commit() - offset += len(draws_batch) - - return total_failed - - -async def auto_fail_bayes_ab(asession: AsyncSession) -> int: - """ - Auto fail experiments draws that have not been updated in a certain amount of time. 
- - """ - total_failed = 0 - now = datetime.now(tz=timezone.utc) - - # Fetch all required experiments data in one query - experiment_query = select(BayesianABDB).where(BayesianABDB.auto_fail.is_(True)) - experiments_result = (await asession.execute(experiment_query)).unique() - experiments = experiments_result.scalars().all() - for experiment in experiments: - hours_threshold = ( - experiment.auto_fail_value * 24 - if experiment.auto_fail_unit == "days" - else experiment.auto_fail_value - ) - - cutoff_datetime = now - timedelta(hours=hours_threshold) - - draws_query = ( - select(BayesianABDrawDB) - .join( - BayesianABDB, - BayesianABDrawDB.experiment_id == BayesianABDB.experiment_id, - ) - .where( - BayesianABDrawDB.experiment_id == experiment.experiment_id, - BayesianABDrawDB.observation_type.is_(None), - BayesianABDrawDB.draw_datetime_utc <= cutoff_datetime, - ) - .limit(100) - ) # Process in smaller batches - - # Paginate through results if there are many draws to avoid memory issues - offset = 0 - while True: - batch_query = draws_query.offset(offset) - draws_result = (await asession.execute(batch_query)).unique() - draws_batch = draws_result.scalars().all() - if not draws_batch: - break - - for draw in draws_batch: - draw.observation_type = ObservationType.AUTO - - await bayes_ab_update_based_on_outcome( - experiment, - draw, - 0.0, - asession, - ObservationType.AUTO, + rewards_list, context_list, treatments_list = ( + await format_rewards_for_arm_update( + experiment, draw.arm_id, 0.0, asession + ) ) - - total_failed += 1 - - await asession.commit() - offset += len(draws_batch) - - return total_failed - - -async def auto_fail_cmab(asession: AsyncSession) -> int: - """ - Auto fail experiments draws that have not been updated in a certain amount of time. 
- - Args: - asession: SQLAlchemy async session - - Returns: - int: Number of draws automatically failed - """ - total_failed = 0 - now = datetime.now(tz=timezone.utc) - - # Fetch all required experiments data in one query - experiment_query = select(ContextualBanditDB).where( - ContextualBanditDB.auto_fail.is_(True) - ) - experiments_result = (await asession.execute(experiment_query)).unique() - experiments = experiments_result.scalars().all() - for experiment in experiments: - hours_threshold = ( - experiment.auto_fail_value * 24 - if experiment.auto_fail_unit == "days" - else experiment.auto_fail_value - ) - - cutoff_datetime = now - timedelta(hours=hours_threshold) - - draws_query = ( - select(ContextualDrawDB) - .join( - ContextualBanditDB, - ContextualDrawDB.experiment_id == ContextualBanditDB.experiment_id, - ) - .where( - ContextualDrawDB.experiment_id == experiment.experiment_id, - ContextualDrawDB.observation_type.is_(None), - ContextualDrawDB.draw_datetime_utc <= cutoff_datetime, - ) - .limit(100) - ) # Process in smaller batches - - # Paginate through results if there are many draws to avoid memory issues - offset = 0 - while True: - batch_query = draws_query.offset(offset) - draws_result = (await asession.execute(batch_query)).unique() - draws_batch = draws_result.scalars().all() - - if not draws_batch: - break - - for draw in draws_batch: - draw.observation_type = ObservationType.AUTO - - await cmab_update_based_on_outcome( + await update_arm_based_on_outcome( experiment, draw, - 0.0, - asession, - ObservationType.AUTO, + rewards_list, + context_list, + treatments_list, ) total_failed += 1 @@ -238,12 +102,8 @@ async def main() -> None: Main function to process notifications """ async for asession in get_async_session(): - failed_count = await auto_fail_mab(asession) - print(f"Auto-failed MABs: {failed_count} draws") - failed_count = await auto_fail_cmab(asession) - print(f"Auto-failed CMABs: {failed_count} draws") - failed_count = await 
auto_fail_bayes_ab(asession) - print(f"Auto-failed Bayes ABs: {failed_count} draws") + failed_count = await auto_fail_experiment(asession) + print(f"Auto-failed experiments: {failed_count} draws") break diff --git a/backend/migrations/versions/275ff74c0866_add_client_id_to_draws_db.py b/backend/migrations/versions/275ff74c0866_add_client_id_to_draws_db.py deleted file mode 100644 index 02d31e3..0000000 --- a/backend/migrations/versions/275ff74c0866_add_client_id_to_draws_db.py +++ /dev/null @@ -1,30 +0,0 @@ -"""add client id to draws db - -Revision ID: 275ff74c0866 -Revises: 5c15463fda65 -Create Date: 2025-04-28 20:01:35.705717 - -""" - -from typing import Sequence, Union - -import sqlalchemy as sa -from alembic import op - -# revision identifiers, used by Alembic. -revision: str = "275ff74c0866" -down_revision: Union[str, None] = "5c15463fda65" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.add_column("draws_base", sa.Column("client_id", sa.String(), nullable=True)) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.drop_column("draws_base", "client_id") - # ### end Alembic commands ### diff --git a/backend/migrations/versions/28adf347e68d_add_tables_for_bayesian_ab_experiments.py b/backend/migrations/versions/28adf347e68d_add_tables_for_bayesian_ab_experiments.py deleted file mode 100644 index 94ecd57..0000000 --- a/backend/migrations/versions/28adf347e68d_add_tables_for_bayesian_ab_experiments.py +++ /dev/null @@ -1,66 +0,0 @@ -"""add tables for Bayesian AB experiments - -Revision ID: 28adf347e68d -Revises: feb042798cad -Create Date: 2025-04-27 11:23:26.823140 - -""" - -from typing import Sequence, Union - -import sqlalchemy as sa -from alembic import op - -# revision identifiers, used by Alembic. 
-revision: str = "28adf347e68d" -down_revision: Union[str, None] = "feb042798cad" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.create_table( - "bayes_ab_experiments", - sa.Column("experiment_id", sa.Integer(), nullable=False), - sa.ForeignKeyConstraint( - ["experiment_id"], ["experiments_base.experiment_id"], ondelete="CASCADE" - ), - sa.PrimaryKeyConstraint("experiment_id"), - ) - op.create_table( - "bayes_ab_arms", - sa.Column("arm_id", sa.Integer(), nullable=False), - sa.Column("mu_init", sa.Float(), nullable=False), - sa.Column("sigma_init", sa.Float(), nullable=False), - sa.Column("mu", sa.Float(), nullable=False), - sa.Column("sigma", sa.Float(), nullable=False), - sa.Column("is_treatment_arm", sa.Boolean(), nullable=False), - sa.ForeignKeyConstraint(["arm_id"], ["arms_base.arm_id"], ondelete="CASCADE"), - sa.PrimaryKeyConstraint("arm_id"), - ) - op.create_table( - "bayes_ab_draws", - sa.Column("draw_id", sa.String(), nullable=False), - sa.ForeignKeyConstraint( - ["draw_id"], ["draws_base.draw_id"], ondelete="CASCADE" - ), - sa.PrimaryKeyConstraint("draw_id"), - ) - op.add_column("mab_arms", sa.Column("alpha_init", sa.Float(), nullable=True)) - op.add_column("mab_arms", sa.Column("beta_init", sa.Float(), nullable=True)) - op.add_column("mab_arms", sa.Column("mu_init", sa.Float(), nullable=True)) - op.add_column("mab_arms", sa.Column("sigma_init", sa.Float(), nullable=True)) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! 
### - op.drop_column("mab_arms", "sigma_init") - op.drop_column("mab_arms", "mu_init") - op.drop_column("mab_arms", "beta_init") - op.drop_column("mab_arms", "alpha_init") - op.drop_table("bayes_ab_draws") - op.drop_table("bayes_ab_arms") - op.drop_table("bayes_ab_experiments") - # ### end Alembic commands ### diff --git a/backend/migrations/versions/5c15463fda65_added_first_name_and_last_name_to_users.py b/backend/migrations/versions/5c15463fda65_added_first_name_and_last_name_to_users.py deleted file mode 100644 index 39b1a4d..0000000 --- a/backend/migrations/versions/5c15463fda65_added_first_name_and_last_name_to_users.py +++ /dev/null @@ -1,36 +0,0 @@ -"""added first name and last name to users - -Revision ID: 5c15463fda65 -Revises: 28adf347e68d -Create Date: 2025-04-26 15:47:23.199751 - -""" - -from typing import Sequence, Union - -import sqlalchemy as sa -from alembic import op - -# revision identifiers, used by Alembic. -revision: str = "5c15463fda65" -down_revision: Union[str, None] = "28adf347e68d" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # Add columns as nullable first - op.add_column("users", sa.Column("first_name", sa.String(), nullable=True)) - op.add_column("users", sa.Column("last_name", sa.String(), nullable=True)) - - # Set default values for existing records - op.execute("UPDATE users SET first_name = '', last_name = ''") - - # Make columns non-nullable - op.alter_column("users", "first_name", nullable=False) - op.alter_column("users", "last_name", nullable=False) - - -def downgrade() -> None: - op.drop_column("users", "last_name") - op.drop_column("users", "first_name") diff --git a/backend/migrations/versions/6101ba814d91_fresh_start.py b/backend/migrations/versions/6101ba814d91_fresh_start.py new file mode 100644 index 0000000..d246310 --- /dev/null +++ b/backend/migrations/versions/6101ba814d91_fresh_start.py @@ -0,0 +1,438 @@ +"""fresh start + 
+Revision ID: 6101ba814d91 +Revises: +Create Date: 2025-06-03 18:00:18.919218 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = "6101ba814d91" +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "users", + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("username", sa.String(), nullable=False), + sa.Column("first_name", sa.String(), nullable=False), + sa.Column("last_name", sa.String(), nullable=False), + sa.Column("hashed_password", sa.String(length=96), nullable=False), + sa.Column("created_datetime_utc", sa.DateTime(timezone=True), nullable=False), + sa.Column("updated_datetime_utc", sa.DateTime(timezone=True), nullable=False), + sa.Column("access_level", sa.String(), nullable=False), + sa.Column("is_active", sa.Boolean(), nullable=False), + sa.Column("is_verified", sa.Boolean(), nullable=False), + sa.PrimaryKeyConstraint("user_id"), + sa.UniqueConstraint("username"), + ) + op.create_table( + "messages", + sa.Column("message_id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("text", sa.String(), nullable=False), + sa.Column("title", sa.String(), nullable=False), + sa.Column("is_unread", sa.Boolean(), nullable=False), + sa.Column("created_datetime_utc", sa.DateTime(timezone=True), nullable=False), + sa.Column("message_type", sa.String(length=50), nullable=False), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.user_id"], + ), + sa.PrimaryKeyConstraint("message_id"), + ) + op.create_table( + "workspace", + sa.Column("api_daily_quota", sa.Integer(), nullable=True), + sa.Column("api_key_first_characters", sa.String(length=5), nullable=True), + 
sa.Column( + "api_key_updated_datetime_utc", sa.DateTime(timezone=True), nullable=True + ), + sa.Column("api_key_rotated_by_user_id", sa.Integer(), nullable=True), + sa.Column("content_quota", sa.Integer(), nullable=True), + sa.Column("created_datetime_utc", sa.DateTime(timezone=True), nullable=False), + sa.Column("hashed_api_key", sa.String(length=96), nullable=True), + sa.Column("updated_datetime_utc", sa.DateTime(timezone=True), nullable=False), + sa.Column("workspace_id", sa.Integer(), nullable=False), + sa.Column("workspace_name", sa.String(), nullable=False), + sa.Column("is_default", sa.Boolean(), nullable=False), + sa.ForeignKeyConstraint( + ["api_key_rotated_by_user_id"], + ["users.user_id"], + ), + sa.PrimaryKeyConstraint("workspace_id"), + sa.UniqueConstraint("hashed_api_key"), + sa.UniqueConstraint("workspace_name"), + ) + op.create_table( + "api_key_rotation_history", + sa.Column("rotation_id", sa.Integer(), nullable=False), + sa.Column("workspace_id", sa.Integer(), nullable=False), + sa.Column("rotated_by_user_id", sa.Integer(), nullable=False), + sa.Column("key_first_characters", sa.String(length=5), nullable=False), + sa.Column("rotation_datetime_utc", sa.DateTime(timezone=True), nullable=False), + sa.ForeignKeyConstraint( + ["rotated_by_user_id"], + ["users.user_id"], + ), + sa.ForeignKeyConstraint( + ["workspace_id"], ["workspace.workspace_id"], ondelete="CASCADE" + ), + sa.PrimaryKeyConstraint("rotation_id"), + ) + op.create_table( + "experiments", + sa.Column("experiment_id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("workspace_id", sa.Integer(), nullable=False), + sa.Column("name", sa.String(length=150), nullable=False), + sa.Column("description", sa.String(length=500), nullable=False), + sa.Column("is_active", sa.Boolean(), nullable=False), + sa.Column("sticky_assignment", sa.Boolean(), nullable=False), + sa.Column("auto_fail", sa.Boolean(), nullable=False), + sa.Column("auto_fail_value", 
sa.Integer(), nullable=True), + sa.Column( + "auto_fail_unit", + sa.Enum("DAYS", "HOURS", name="autofailunittype"), + nullable=True, + ), + sa.Column("exp_type", sa.String(length=50), nullable=False), + sa.Column("prior_type", sa.String(length=50), nullable=False), + sa.Column("reward_type", sa.String(length=50), nullable=False), + sa.Column("created_datetime_utc", sa.DateTime(timezone=True), nullable=False), + sa.Column("n_trials", sa.Integer(), nullable=False), + sa.Column("last_trial_datetime_utc", sa.DateTime(timezone=True), nullable=True), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.user_id"], + ), + sa.ForeignKeyConstraint( + ["workspace_id"], + ["workspace.workspace_id"], + ), + sa.PrimaryKeyConstraint("experiment_id"), + ) + op.create_table( + "experiments_base", + sa.Column("experiment_id", sa.Integer(), nullable=False), + sa.Column("name", sa.String(length=150), nullable=False), + sa.Column("description", sa.String(length=500), nullable=False), + sa.Column("sticky_assignment", sa.Boolean(), nullable=False), + sa.Column("auto_fail", sa.Boolean(), nullable=False), + sa.Column("auto_fail_value", sa.Integer(), nullable=True), + sa.Column( + "auto_fail_unit", + sa.Enum("DAYS", "HOURS", name="autofailunittype"), + nullable=True, + ), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("workspace_id", sa.Integer(), nullable=False), + sa.Column("is_active", sa.Boolean(), nullable=False), + sa.Column("exp_type", sa.String(length=50), nullable=False), + sa.Column("prior_type", sa.String(length=50), nullable=False), + sa.Column("reward_type", sa.String(length=50), nullable=False), + sa.Column("created_datetime_utc", sa.DateTime(timezone=True), nullable=False), + sa.Column("n_trials", sa.Integer(), nullable=False), + sa.Column("last_trial_datetime_utc", sa.DateTime(timezone=True), nullable=True), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.user_id"], + ), + sa.ForeignKeyConstraint( + ["workspace_id"], + ["workspace.workspace_id"], + ), + 
sa.PrimaryKeyConstraint("experiment_id"), + ) + op.create_table( + "pending_invitations", + sa.Column("invitation_id", sa.Integer(), nullable=False), + sa.Column("email", sa.String(), nullable=False), + sa.Column("workspace_id", sa.Integer(), nullable=False), + sa.Column( + "role", + sa.Enum("ADMIN", "READ_ONLY", name="userroles", native_enum=False), + nullable=False, + ), + sa.Column("inviter_id", sa.Integer(), nullable=False), + sa.Column("created_datetime_utc", sa.DateTime(timezone=True), nullable=False), + sa.ForeignKeyConstraint( + ["inviter_id"], + ["users.user_id"], + ), + sa.ForeignKeyConstraint( + ["workspace_id"], ["workspace.workspace_id"], ondelete="CASCADE" + ), + sa.PrimaryKeyConstraint("invitation_id"), + ) + op.create_table( + "user_workspace", + sa.Column("created_datetime_utc", sa.DateTime(timezone=True), nullable=False), + sa.Column( + "default_workspace", + sa.Boolean(), + server_default=sa.text("false"), + nullable=False, + ), + sa.Column("updated_datetime_utc", sa.DateTime(timezone=True), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column( + "user_role", + sa.Enum("ADMIN", "READ_ONLY", name="userroles", native_enum=False), + nullable=False, + ), + sa.Column("workspace_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(["user_id"], ["users.user_id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint( + ["workspace_id"], ["workspace.workspace_id"], ondelete="CASCADE" + ), + sa.PrimaryKeyConstraint("user_id", "workspace_id"), + ) + op.create_table( + "arms", + sa.Column("arm_id", sa.Integer(), nullable=False), + sa.Column("workspace_id", sa.Integer(), nullable=False), + sa.Column("experiment_id", sa.Integer(), nullable=False), + sa.Column("name", sa.String(length=150), nullable=False), + sa.Column("description", sa.String(length=500), nullable=False), + sa.Column("n_outcomes", sa.Integer(), nullable=False), + sa.Column("mu_init", sa.Float(), nullable=True), + sa.Column("sigma_init", sa.Float(), 
nullable=True), + sa.Column("mu", postgresql.ARRAY(sa.Float()), nullable=True), + sa.Column("covariance", postgresql.ARRAY(sa.Float()), nullable=True), + sa.Column("is_treatment_arm", sa.Boolean(), nullable=True), + sa.Column("alpha_init", sa.Float(), nullable=True), + sa.Column("beta_init", sa.Float(), nullable=True), + sa.Column("alpha", sa.Float(), nullable=True), + sa.Column("beta", sa.Float(), nullable=True), + sa.ForeignKeyConstraint( + ["experiment_id"], + ["experiments.experiment_id"], + ), + sa.ForeignKeyConstraint( + ["workspace_id"], + ["workspace.workspace_id"], + ), + sa.PrimaryKeyConstraint("arm_id"), + ) + op.create_table( + "arms_base", + sa.Column("arm_id", sa.Integer(), nullable=False), + sa.Column("experiment_id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("name", sa.String(length=150), nullable=False), + sa.Column("description", sa.String(length=500), nullable=False), + sa.Column("arm_type", sa.String(length=50), nullable=False), + sa.Column("n_outcomes", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["experiment_id"], + ["experiments_base.experiment_id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.user_id"], + ), + sa.PrimaryKeyConstraint("arm_id"), + ) + op.create_table( + "clients", + sa.Column("client_id", sa.String(), nullable=False), + sa.Column("experiment_id", sa.Integer(), nullable=False), + sa.Column("workspace_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["experiment_id"], + ["experiments.experiment_id"], + ), + sa.ForeignKeyConstraint( + ["workspace_id"], + ["workspace.workspace_id"], + ), + sa.PrimaryKeyConstraint("client_id"), + ) + op.create_table( + "context", + sa.Column("context_id", sa.Integer(), nullable=False), + sa.Column("experiment_id", sa.Integer(), nullable=False), + sa.Column("workspace_id", sa.Integer(), nullable=False), + sa.Column("name", sa.String(length=150), nullable=False), + sa.Column("description", 
sa.String(length=500), nullable=True), + sa.Column("value_type", sa.String(length=50), nullable=False), + sa.ForeignKeyConstraint( + ["experiment_id"], + ["experiments.experiment_id"], + ), + sa.ForeignKeyConstraint( + ["workspace_id"], + ["workspace.workspace_id"], + ), + sa.PrimaryKeyConstraint("context_id"), + ) + op.create_table( + "event_messages", + sa.Column("message_id", sa.Integer(), nullable=False), + sa.Column("experiment_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["experiment_id"], + ["experiments_base.experiment_id"], + ), + sa.ForeignKeyConstraint( + ["message_id"], ["messages.message_id"], ondelete="CASCADE" + ), + sa.PrimaryKeyConstraint("message_id"), + ) + op.create_table( + "notifications", + sa.Column("notification_id", sa.Integer(), nullable=False), + sa.Column("experiment_id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("workspace_id", sa.Integer(), nullable=False), + sa.Column( + "notification_type", + sa.Enum( + "DAYS_ELAPSED", + "TRIALS_COMPLETED", + "PERCENTAGE_BETTER", + name="eventtype", + ), + nullable=False, + ), + sa.Column("notification_value", sa.Integer(), nullable=False), + sa.Column("is_active", sa.Boolean(), nullable=False), + sa.ForeignKeyConstraint( + ["experiment_id"], + ["experiments.experiment_id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.user_id"], + ), + sa.ForeignKeyConstraint( + ["workspace_id"], + ["workspace.workspace_id"], + ), + sa.PrimaryKeyConstraint("notification_id"), + ) + op.create_table( + "notifications_db", + sa.Column("notification_id", sa.Integer(), nullable=False), + sa.Column("experiment_id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column( + "notification_type", + sa.Enum( + "DAYS_ELAPSED", + "TRIALS_COMPLETED", + "PERCENTAGE_BETTER", + name="eventtype", + ), + nullable=False, + ), + sa.Column("notification_value", sa.Integer(), nullable=False), + 
sa.Column("is_active", sa.Boolean(), nullable=False), + sa.ForeignKeyConstraint( + ["experiment_id"], + ["experiments_base.experiment_id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.user_id"], + ), + sa.PrimaryKeyConstraint("notification_id"), + ) + op.create_table( + "draws", + sa.Column("draw_id", sa.String(), nullable=False), + sa.Column("arm_id", sa.Integer(), nullable=False), + sa.Column("experiment_id", sa.Integer(), nullable=False), + sa.Column("workspace_id", sa.Integer(), nullable=False), + sa.Column("client_id", sa.String(length=36), nullable=True), + sa.Column("draw_datetime_utc", sa.DateTime(timezone=True), nullable=False), + sa.Column("observed_datetime_utc", sa.DateTime(timezone=True), nullable=True), + sa.Column( + "observation_type", + sa.Enum("USER", "AUTO", name="observationtype"), + nullable=True, + ), + sa.Column("reward", sa.Float(), nullable=True), + sa.Column("context_val", postgresql.ARRAY(sa.Float()), nullable=True), + sa.ForeignKeyConstraint( + ["arm_id"], + ["arms.arm_id"], + ), + sa.ForeignKeyConstraint( + ["client_id"], + ["clients.client_id"], + ), + sa.ForeignKeyConstraint( + ["experiment_id"], + ["experiments.experiment_id"], + ), + sa.ForeignKeyConstraint( + ["workspace_id"], + ["workspace.workspace_id"], + ), + sa.PrimaryKeyConstraint("draw_id"), + ) + op.create_table( + "draws_base", + sa.Column("draw_id", sa.String(), nullable=False), + sa.Column("client_id", sa.String(), nullable=True), + sa.Column("arm_id", sa.Integer(), nullable=False), + sa.Column("experiment_id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("draw_datetime_utc", sa.DateTime(timezone=True), nullable=False), + sa.Column("observed_datetime_utc", sa.DateTime(timezone=True), nullable=True), + sa.Column( + "observation_type", + sa.Enum("USER", "AUTO", name="observationtype"), + nullable=True, + ), + sa.Column("draw_type", sa.String(length=50), nullable=False), + sa.Column("reward", sa.Float(), 
nullable=True), + sa.ForeignKeyConstraint( + ["arm_id"], + ["arms_base.arm_id"], + ), + sa.ForeignKeyConstraint( + ["experiment_id"], + ["experiments_base.experiment_id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.user_id"], + ), + sa.PrimaryKeyConstraint("draw_id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("draws_base") + op.drop_table("draws") + op.drop_table("notifications_db") + op.drop_table("notifications") + op.drop_table("event_messages") + op.drop_table("context") + op.drop_table("clients") + op.drop_table("arms_base") + op.drop_table("arms") + op.drop_table("user_workspace") + op.drop_table("pending_invitations") + op.drop_table("experiments_base") + op.drop_table("experiments") + op.drop_table("api_key_rotation_history") + op.drop_table("workspace") + op.drop_table("messages") + op.drop_table("users") + # ### end Alembic commands ### diff --git a/backend/migrations/versions/9f7482ba882f_workspace_model.py b/backend/migrations/versions/9f7482ba882f_workspace_model.py deleted file mode 100644 index 3543211..0000000 --- a/backend/migrations/versions/9f7482ba882f_workspace_model.py +++ /dev/null @@ -1,123 +0,0 @@ -"""Workspace model - -Revision ID: 9f7482ba882f -Revises: 275ff74c0866 -Create Date: 2025-05-04 11:56:03.939578 - -""" - -from typing import Sequence, Union - -import sqlalchemy as sa -from alembic import op - -# revision identifiers, used by Alembic. -revision: str = "9f7482ba882f" -down_revision: Union[str, None] = "275ff74c0866" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! 
### - op.create_table( - "workspace", - sa.Column("api_daily_quota", sa.Integer(), nullable=True), - sa.Column("api_key_first_characters", sa.String(length=5), nullable=True), - sa.Column( - "api_key_updated_datetime_utc", sa.DateTime(timezone=True), nullable=True - ), - sa.Column("api_key_rotated_by_user_id", sa.Integer(), nullable=True), - sa.Column("content_quota", sa.Integer(), nullable=True), - sa.Column("created_datetime_utc", sa.DateTime(timezone=True), nullable=False), - sa.Column("hashed_api_key", sa.String(length=96), nullable=True), - sa.Column("updated_datetime_utc", sa.DateTime(timezone=True), nullable=False), - sa.Column("workspace_id", sa.Integer(), nullable=False), - sa.Column("workspace_name", sa.String(), nullable=False), - sa.Column("is_default", sa.Boolean(), nullable=False), - sa.ForeignKeyConstraint( - ["api_key_rotated_by_user_id"], - ["users.user_id"], - ), - sa.PrimaryKeyConstraint("workspace_id"), - sa.UniqueConstraint("hashed_api_key"), - sa.UniqueConstraint("workspace_name"), - ) - op.create_table( - "api_key_rotation_history", - sa.Column("rotation_id", sa.Integer(), nullable=False), - sa.Column("workspace_id", sa.Integer(), nullable=False), - sa.Column("rotated_by_user_id", sa.Integer(), nullable=False), - sa.Column("key_first_characters", sa.String(length=5), nullable=False), - sa.Column("rotation_datetime_utc", sa.DateTime(timezone=True), nullable=False), - sa.ForeignKeyConstraint( - ["rotated_by_user_id"], - ["users.user_id"], - ), - sa.ForeignKeyConstraint( - ["workspace_id"], ["workspace.workspace_id"], ondelete="CASCADE" - ), - sa.PrimaryKeyConstraint("rotation_id"), - ) - op.create_table( - "pending_invitations", - sa.Column("invitation_id", sa.Integer(), nullable=False), - sa.Column("email", sa.String(), nullable=False), - sa.Column("workspace_id", sa.Integer(), nullable=False), - sa.Column( - "role", - sa.Enum("ADMIN", "READ_ONLY", name="userroles", native_enum=False), - nullable=False, - ), - sa.Column("inviter_id", 
sa.Integer(), nullable=False), - sa.Column("created_datetime_utc", sa.DateTime(timezone=True), nullable=False), - sa.ForeignKeyConstraint( - ["inviter_id"], - ["users.user_id"], - ), - sa.ForeignKeyConstraint( - ["workspace_id"], ["workspace.workspace_id"], ondelete="CASCADE" - ), - sa.PrimaryKeyConstraint("invitation_id"), - ) - op.create_table( - "user_workspace", - sa.Column("created_datetime_utc", sa.DateTime(timezone=True), nullable=False), - sa.Column( - "default_workspace", - sa.Boolean(), - server_default=sa.text("false"), - nullable=False, - ), - sa.Column("updated_datetime_utc", sa.DateTime(timezone=True), nullable=False), - sa.Column("user_id", sa.Integer(), nullable=False), - sa.Column( - "user_role", - sa.Enum("ADMIN", "READ_ONLY", name="userroles", native_enum=False), - nullable=False, - ), - sa.Column("workspace_id", sa.Integer(), nullable=False), - sa.ForeignKeyConstraint(["user_id"], ["users.user_id"], ondelete="CASCADE"), - sa.ForeignKeyConstraint( - ["workspace_id"], ["workspace.workspace_id"], ondelete="CASCADE" - ), - sa.PrimaryKeyConstraint("user_id", "workspace_id"), - ) - op.add_column( - "experiments_base", sa.Column("workspace_id", sa.Integer(), nullable=False) - ) - op.create_foreign_key( - None, "experiments_base", "workspace", ["workspace_id"], ["workspace_id"] - ) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! 
### - op.drop_constraint(None, "experiments_base", type_="foreignkey") - op.drop_column("experiments_base", "workspace_id") - op.drop_table("user_workspace") - op.drop_table("pending_invitations") - op.drop_table("api_key_rotation_history") - op.drop_table("workspace") - # ### end Alembic commands ### diff --git a/backend/migrations/versions/ecddd830b464_remove_user_api_key.py b/backend/migrations/versions/ecddd830b464_remove_user_api_key.py deleted file mode 100644 index b03b032..0000000 --- a/backend/migrations/versions/ecddd830b464_remove_user_api_key.py +++ /dev/null @@ -1,70 +0,0 @@ -"""Remove User API key - -Revision ID: ecddd830b464 -Revises: 9f7482ba882f -Create Date: 2025-05-21 13:59:22.199884 - -""" - -from typing import Sequence, Union - -import sqlalchemy as sa -from alembic import op -from sqlalchemy.dialects import postgresql - -# revision identifiers, used by Alembic. -revision: str = "ecddd830b464" -down_revision: Union[str, None] = "9f7482ba882f" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.drop_constraint("users_hashed_api_key_key", "users", type_="unique") - op.drop_column("users", "api_daily_quota") - op.drop_column("users", "hashed_api_key") - op.drop_column("users", "api_key_updated_datetime_utc") - op.drop_column("users", "api_key_first_characters") - op.drop_column("users", "experiments_quota") - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! 
### - op.add_column( - "users", - sa.Column( - "experiments_quota", sa.INTEGER(), autoincrement=False, nullable=True - ), - ) - op.add_column( - "users", - sa.Column( - "api_key_first_characters", - sa.VARCHAR(length=5), - autoincrement=False, - nullable=False, - ), - ) - op.add_column( - "users", - sa.Column( - "api_key_updated_datetime_utc", - postgresql.TIMESTAMP(timezone=True), - autoincrement=False, - nullable=False, - ), - ) - op.add_column( - "users", - sa.Column( - "hashed_api_key", sa.VARCHAR(length=96), autoincrement=False, nullable=False - ), - ) - op.add_column( - "users", - sa.Column("api_daily_quota", sa.INTEGER(), autoincrement=False, nullable=True), - ) - op.create_unique_constraint("users_hashed_api_key_key", "users", ["hashed_api_key"]) - # ### end Alembic commands ### diff --git a/backend/migrations/versions/faf4228e13a3_clean_start.py b/backend/migrations/versions/faf4228e13a3_clean_start.py deleted file mode 100644 index 71af813..0000000 --- a/backend/migrations/versions/faf4228e13a3_clean_start.py +++ /dev/null @@ -1,257 +0,0 @@ -"""clean start - -Revision ID: faf4228e13a3 -Revises: -Create Date: 2025-04-17 21:18:03.761219 - -""" - -from typing import Sequence, Union - -import sqlalchemy as sa -from alembic import op -from sqlalchemy.dialects import postgresql - -# revision identifiers, used by Alembic. -revision: str = "faf4228e13a3" -down_revision: Union[str, None] = None -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! 
### - op.create_table( - "users", - sa.Column("user_id", sa.Integer(), nullable=False), - sa.Column("username", sa.String(), nullable=False), - sa.Column("hashed_password", sa.String(length=96), nullable=False), - sa.Column("hashed_api_key", sa.String(length=96), nullable=False), - sa.Column("api_key_first_characters", sa.String(length=5), nullable=False), - sa.Column( - "api_key_updated_datetime_utc", sa.DateTime(timezone=True), nullable=False - ), - sa.Column("experiments_quota", sa.Integer(), nullable=True), - sa.Column("api_daily_quota", sa.Integer(), nullable=True), - sa.Column("created_datetime_utc", sa.DateTime(timezone=True), nullable=False), - sa.Column("updated_datetime_utc", sa.DateTime(timezone=True), nullable=False), - sa.Column("access_level", sa.String(), nullable=False), - sa.Column("is_active", sa.Boolean(), nullable=False), - sa.Column("is_verified", sa.Boolean(), nullable=False), - sa.PrimaryKeyConstraint("user_id"), - sa.UniqueConstraint("hashed_api_key"), - sa.UniqueConstraint("username"), - ) - op.create_table( - "experiments_base", - sa.Column("experiment_id", sa.Integer(), nullable=False), - sa.Column("name", sa.String(length=150), nullable=False), - sa.Column("description", sa.String(length=500), nullable=False), - sa.Column("user_id", sa.Integer(), nullable=False), - sa.Column("is_active", sa.Boolean(), nullable=False), - sa.Column("exp_type", sa.String(length=50), nullable=False), - sa.Column("prior_type", sa.String(length=50), nullable=False), - sa.Column("reward_type", sa.String(length=50), nullable=False), - sa.Column("created_datetime_utc", sa.DateTime(timezone=True), nullable=False), - sa.Column("n_trials", sa.Integer(), nullable=False), - sa.Column("last_trial_datetime_utc", sa.DateTime(timezone=True), nullable=True), - sa.ForeignKeyConstraint( - ["user_id"], - ["users.user_id"], - ), - sa.PrimaryKeyConstraint("experiment_id"), - ) - op.create_table( - "messages", - sa.Column("message_id", sa.Integer(), nullable=False), - 
sa.Column("user_id", sa.Integer(), nullable=False), - sa.Column("text", sa.String(), nullable=False), - sa.Column("title", sa.String(), nullable=False), - sa.Column("is_unread", sa.Boolean(), nullable=False), - sa.Column("created_datetime_utc", sa.DateTime(timezone=True), nullable=False), - sa.Column("message_type", sa.String(length=50), nullable=False), - sa.ForeignKeyConstraint( - ["user_id"], - ["users.user_id"], - ), - sa.PrimaryKeyConstraint("message_id"), - ) - op.create_table( - "arms_base", - sa.Column("arm_id", sa.Integer(), nullable=False), - sa.Column("experiment_id", sa.Integer(), nullable=False), - sa.Column("user_id", sa.Integer(), nullable=False), - sa.Column("name", sa.String(length=150), nullable=False), - sa.Column("description", sa.String(length=500), nullable=False), - sa.Column("arm_type", sa.String(length=50), nullable=False), - sa.Column("n_outcomes", sa.Integer(), nullable=False), - sa.ForeignKeyConstraint( - ["experiment_id"], - ["experiments_base.experiment_id"], - ), - sa.ForeignKeyConstraint( - ["user_id"], - ["users.user_id"], - ), - sa.PrimaryKeyConstraint("arm_id"), - ) - op.create_table( - "contextual_mabs", - sa.Column("experiment_id", sa.Integer(), nullable=False), - sa.ForeignKeyConstraint( - ["experiment_id"], ["experiments_base.experiment_id"], ondelete="CASCADE" - ), - sa.PrimaryKeyConstraint("experiment_id"), - ) - op.create_table( - "event_messages", - sa.Column("message_id", sa.Integer(), nullable=False), - sa.Column("experiment_id", sa.Integer(), nullable=False), - sa.ForeignKeyConstraint( - ["experiment_id"], - ["experiments_base.experiment_id"], - ), - sa.ForeignKeyConstraint( - ["message_id"], ["messages.message_id"], ondelete="CASCADE" - ), - sa.PrimaryKeyConstraint("message_id"), - ) - op.create_table( - "mabs", - sa.Column("experiment_id", sa.Integer(), nullable=False), - sa.ForeignKeyConstraint( - ["experiment_id"], ["experiments_base.experiment_id"], ondelete="CASCADE" - ), - 
sa.PrimaryKeyConstraint("experiment_id"), - ) - op.create_table( - "notifications", - sa.Column("notification_id", sa.Integer(), nullable=False), - sa.Column("experiment_id", sa.Integer(), nullable=False), - sa.Column("user_id", sa.Integer(), nullable=False), - sa.Column( - "notification_type", - sa.Enum( - "DAYS_ELAPSED", - "TRIALS_COMPLETED", - "PERCENTAGE_BETTER", - name="eventtype", - ), - nullable=False, - ), - sa.Column("notification_value", sa.Integer(), nullable=False), - sa.Column("is_active", sa.Boolean(), nullable=False), - sa.ForeignKeyConstraint( - ["experiment_id"], - ["experiments_base.experiment_id"], - ), - sa.ForeignKeyConstraint( - ["user_id"], - ["users.user_id"], - ), - sa.PrimaryKeyConstraint("notification_id"), - ) - op.create_table( - "contexts", - sa.Column("context_id", sa.Integer(), nullable=False), - sa.Column("experiment_id", sa.Integer(), nullable=False), - sa.Column("user_id", sa.Integer(), nullable=False), - sa.Column("name", sa.String(length=150), nullable=False), - sa.Column("description", sa.String(length=500), nullable=True), - sa.Column("value_type", sa.String(length=50), nullable=False), - sa.ForeignKeyConstraint( - ["experiment_id"], - ["contextual_mabs.experiment_id"], - ), - sa.ForeignKeyConstraint( - ["user_id"], - ["users.user_id"], - ), - sa.PrimaryKeyConstraint("context_id"), - ) - op.create_table( - "contextual_arms", - sa.Column("arm_id", sa.Integer(), nullable=False), - sa.Column("mu_init", sa.Float(), nullable=False), - sa.Column("sigma_init", sa.Float(), nullable=False), - sa.Column("mu", postgresql.ARRAY(sa.Float()), nullable=False), - sa.Column("covariance", postgresql.ARRAY(sa.Float()), nullable=False), - sa.ForeignKeyConstraint(["arm_id"], ["arms_base.arm_id"], ondelete="CASCADE"), - sa.PrimaryKeyConstraint("arm_id"), - ) - op.create_table( - "draws_base", - sa.Column("draw_id", sa.String(), nullable=False), - sa.Column("arm_id", sa.Integer(), nullable=False), - sa.Column("experiment_id", sa.Integer(), 
nullable=False), - sa.Column("user_id", sa.Integer(), nullable=False), - sa.Column("draw_datetime_utc", sa.DateTime(timezone=True), nullable=False), - sa.Column("observed_datetime_utc", sa.DateTime(timezone=True), nullable=True), - sa.Column( - "observation_type", - sa.Enum("USER", "AUTO", name="observationtype"), - nullable=True, - ), - sa.Column("draw_type", sa.String(length=50), nullable=False), - sa.Column("reward", sa.Float(), nullable=True), - sa.ForeignKeyConstraint( - ["arm_id"], - ["arms_base.arm_id"], - ), - sa.ForeignKeyConstraint( - ["experiment_id"], - ["experiments_base.experiment_id"], - ), - sa.ForeignKeyConstraint( - ["user_id"], - ["users.user_id"], - ), - sa.PrimaryKeyConstraint("draw_id"), - ) - op.create_table( - "mab_arms", - sa.Column("arm_id", sa.Integer(), nullable=False), - sa.Column("alpha", sa.Float(), nullable=True), - sa.Column("beta", sa.Float(), nullable=True), - sa.Column("mu", sa.Float(), nullable=True), - sa.Column("sigma", sa.Float(), nullable=True), - sa.ForeignKeyConstraint(["arm_id"], ["arms_base.arm_id"], ondelete="CASCADE"), - sa.PrimaryKeyConstraint("arm_id"), - ) - op.create_table( - "contextual_draws", - sa.Column("draw_id", sa.String(), nullable=False), - sa.Column("context_val", postgresql.ARRAY(sa.Float()), nullable=False), - sa.ForeignKeyConstraint( - ["draw_id"], ["draws_base.draw_id"], ondelete="CASCADE" - ), - sa.PrimaryKeyConstraint("draw_id"), - ) - op.create_table( - "mab_draws", - sa.Column("draw_id", sa.String(), nullable=False), - sa.ForeignKeyConstraint( - ["draw_id"], ["draws_base.draw_id"], ondelete="CASCADE" - ), - sa.PrimaryKeyConstraint("draw_id"), - ) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! 
### - op.drop_table("mab_draws") - op.drop_table("contextual_draws") - op.drop_table("mab_arms") - op.drop_table("draws_base") - op.drop_table("contextual_arms") - op.drop_table("contexts") - op.drop_table("notifications") - op.drop_table("mabs") - op.drop_table("event_messages") - op.drop_table("contextual_mabs") - op.drop_table("arms_base") - op.drop_table("messages") - op.drop_table("experiments_base") - op.drop_table("users") - # ### end Alembic commands ### diff --git a/backend/migrations/versions/feb042798cad_added_sticky_assignments_and_autofail.py b/backend/migrations/versions/feb042798cad_added_sticky_assignments_and_autofail.py deleted file mode 100644 index 824c2ba..0000000 --- a/backend/migrations/versions/feb042798cad_added_sticky_assignments_and_autofail.py +++ /dev/null @@ -1,59 +0,0 @@ -"""added sticky assignments and autofail - -Revision ID: feb042798cad -Revises: faf4228e13a3 -Create Date: 2025-04-18 15:11:40.688651 - -""" - -from typing import Sequence, Union - -import sqlalchemy as sa -from alembic import op - -# revision identifiers, used by Alembic. -revision: str = "feb042798cad" -down_revision: Union[str, None] = "faf4228e13a3" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! 
### - autofailunittype = sa.Enum( - "DAYS", - "HOURS", - name="autofailunittype", - ) - autofailunittype.create(op.get_bind()) - - op.add_column( - "experiments_base", sa.Column("sticky_assignment", sa.Boolean(), nullable=False) - ) - op.add_column( - "experiments_base", sa.Column("auto_fail", sa.Boolean(), nullable=False) - ) - op.add_column( - "experiments_base", sa.Column("auto_fail_value", sa.Integer(), nullable=True) - ) - op.add_column( - "experiments_base", - sa.Column( - "auto_fail_unit", - autofailunittype, - nullable=True, - ), - ) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.drop_column("experiments_base", "auto_fail_unit") - op.drop_column("experiments_base", "auto_fail_value") - op.drop_column("experiments_base", "auto_fail") - op.drop_column("experiments_base", "sticky_assignment") - - sa.Enum(name="autofailunittype").drop(op.get_bind()) - - # ### end Alembic commands ### diff --git a/backend/tests/pytest.ini b/backend/tests/pytest.ini new file mode 100644 index 0000000..22932f8 --- /dev/null +++ b/backend/tests/pytest.ini @@ -0,0 +1,5 @@ +[pytest] +asyncio_mode = auto + +# Set the default fixture loop scope to function +asyncio_default_fixture_loop_scope = function diff --git a/deployment/docker-compose/docker-compose-dev.yml b/deployment/docker-compose/docker-compose-dev.yml index 6a7b5df..25db623 100644 --- a/deployment/docker-compose/docker-compose-dev.yml +++ b/deployment/docker-compose/docker-compose-dev.yml @@ -4,13 +4,17 @@ services: build: context: ../../backend dockerfile: Dockerfile + entrypoint: ["/bin/sh", "-c"] command: > - python -m alembic upgrade head && python add_users_to_db.py && uvicorn main:app --host 0.0.0.0 --port 8000 --reload + "python -m alembic upgrade head && + python add_users_to_db.py && + uvicorn main:app --host 0.0.0.0 --port 8000 --reload + " restart: always ports: - "8000:8000" volumes: - # - 
temp:/usr/src/experiment_engine_backend/temp + - temp:/usr/src/experiment_engine_backend/temp - ../../backend:/usr/src/experiment_engine_backend env_file: - .base.env @@ -65,7 +69,7 @@ services: redis: image: "redis:6.0-alpine" - ports: # Expose the port to port 6380 on the host machine for debugging + ports: - "6380:6379" restart: always @@ -73,4 +77,4 @@ volumes: db_volume: caddy_data: caddy_config: - # temp: + temp: diff --git a/deployment/docker-compose/docker-compose.yml b/deployment/docker-compose/docker-compose.yml index c326f49..307679c 100644 --- a/deployment/docker-compose/docker-compose.yml +++ b/deployment/docker-compose/docker-compose.yml @@ -57,7 +57,7 @@ services: redis: image: "redis:6.0-alpine" - ports: # Expose the port to port 6380 on the host machine for debugging + ports: - "6380:6379" restart: always diff --git a/frontend/src/app/(protected)/workspaces/page.tsx b/frontend/src/app/(protected)/workspaces/page.tsx index 2e7b6bc..872d091 100644 --- a/frontend/src/app/(protected)/workspaces/page.tsx +++ b/frontend/src/app/(protected)/workspaces/page.tsx @@ -151,8 +151,8 @@ export default function WorkspacesPage() { -
- +

API Configuration

API Key Prefix:
{currentWorkspace.api_key_first_characters}•••••
- +
Key Last Rotated:
{new Date(currentWorkspace.api_key_updated_datetime_utc).toLocaleDateString()}
@@ -220,7 +220,7 @@ export default function WorkspacesPage() {
-
- diff --git a/frontend/src/components/app-sidebar.tsx b/frontend/src/components/app-sidebar.tsx index 7295125..9293a21 100644 --- a/frontend/src/components/app-sidebar.tsx +++ b/frontend/src/components/app-sidebar.tsx @@ -6,7 +6,7 @@ import { Map, PieChart, Settings2, - FlaskConicalIcon, + FlaskConicalIcon } from "lucide-react"; import { NavMain } from "@/components/nav-main"; import { NavRecentExperiments } from "@/components/nav-recent-experiments"; @@ -22,39 +22,10 @@ import { import { apiCalls } from "@/utils/api"; import { useAuth } from "@/utils/auth"; -type UserDetails = { - username: string; - firstName: string; - lastName: string; - isActive: boolean; - isVerified: boolean; -}; - -const getUserDetails = async (token: string | null) => { - try { - if (token) { - const response = await apiCalls.getUser(token); - if (!response) { - throw new Error("No response from server"); - } - return { - username: response.username, - firstName: response.first_name, - lastName: response.last_name, - isActive: response.is_active, - isVerified: response.is_verified, - } as UserDetails; - } else { - throw new Error("No token provided"); - } -} catch (error: unknown) { - if (error instanceof Error) { - throw new Error(`Error fetching user details: ${error.message}`); - } else { - throw new Error("Error fetching user details"); - } - } -}; +const AppSidebar: React.FC> = React.memo(function AppSidebar({ + ...props +}) { + const { user, firstName, lastName} = useAuth(); // This is sample data. 
const data = { @@ -79,40 +50,33 @@ const data = { url: "#", icon: Settings2, }, - ], - recentExperiments: [ + ] +}; + + const recentExperiments = [ { - name: "New onboarding flows", + name: "Recent Experiment", url: "#", - icon: Frame, + icon: FlaskConicalIcon }, { name: "3 different voices", url: "#", - icon: PieChart, + icon: FlaskConicalIcon }, { name: "AI responses", url: "#", - icon: Map, - }, - ], -}; -const AppSidebar = React.memo(function AppSidebar({ - ...props -}: React.ComponentProps) { - const { token } = useAuth(); - const [userDetails, setUserDetails] = React.useState( - null - ); - - React.useEffect(() => { - if (token) { - getUserDetails(token) - .then((data) => setUserDetails(data)) - .catch((error) => console.error(error)); + icon: FlaskConicalIcon } - }, [token]); + ]; + + const userDetails = { + firstName: firstName || "?", + lastName: lastName || "?", + username: user || "loading" + }; + return ( @@ -120,7 +84,7 @@ const AppSidebar = React.memo(function AppSidebar({ - + diff --git a/frontend/src/components/ui/tabs.tsx b/frontend/src/components/ui/tabs.tsx index 8873b85..26eb109 100644 --- a/frontend/src/components/ui/tabs.tsx +++ b/frontend/src/components/ui/tabs.tsx @@ -52,4 +52,4 @@ const TabsContent = React.forwardRef< )) TabsContent.displayName = TabsPrimitive.Content.displayName -export { Tabs, TabsList, TabsTrigger, TabsContent } \ No newline at end of file +export { Tabs, TabsList, TabsTrigger, TabsContent } diff --git a/frontend/src/components/workspace-switcher.tsx b/frontend/src/components/workspace-switcher.tsx index 088cdfc..c3f93fb 100644 --- a/frontend/src/components/workspace-switcher.tsx +++ b/frontend/src/components/workspace-switcher.tsx @@ -73,6 +73,7 @@ export function WorkspaceSwitcher() { ); } + return ( diff --git a/frontend/src/utils/auth.tsx b/frontend/src/utils/auth.tsx index 46a243b..73c7419 100644 --- a/frontend/src/utils/auth.tsx +++ b/frontend/src/utils/auth.tsx @@ -83,7 +83,7 @@ const AuthProvider = ({ children }: 
AuthProviderProps) => { setIsVerified(userData.is_verified); setFirstName(userData.first_name); setLastName(userData.last_name); - + // Fetch current workspace await fetchCurrentWorkspace(); @@ -103,7 +103,7 @@ const AuthProvider = ({ children }: AuthProviderProps) => { const fetchCurrentWorkspace = async () => { if (!token) return; - + try { const workspaceData = await apiCalls.getCurrentWorkspace(token); setCurrentWorkspace(workspaceData); @@ -114,7 +114,7 @@ const AuthProvider = ({ children }: AuthProviderProps) => { const fetchWorkspaces = async () => { if (!token) return; - + try { const workspacesData = await apiCalls.getUserWorkspaces(token); setWorkspaces(workspacesData); @@ -125,18 +125,18 @@ const AuthProvider = ({ children }: AuthProviderProps) => { const switchWorkspace = async (workspaceName: string) => { if (!token) return; - + try { setIsLoading(true); const authResponse = await apiCalls.switchWorkspace(token, workspaceName); - + // Update token and other auth details localStorage.setItem("ee-token", authResponse.access_token); setToken(authResponse.access_token); - + // Refresh workspace data await fetchCurrentWorkspace(); - + return; } catch (error) { console.error("Error switching workspace:", error); @@ -148,14 +148,14 @@ const AuthProvider = ({ children }: AuthProviderProps) => { const rotateWorkspaceApiKey = async () => { if (!token) throw new Error("Not authenticated"); - + try { setIsLoading(true); const response = await apiCalls.rotateWorkspaceApiKey(token); - + // Update workspace to reflect key change await fetchCurrentWorkspace(); - + return response.new_api_key; } catch (error) { console.error("Error rotating workspace API key:", error); From d71746a897899b479f0a9e760417c5b73d1c2112 Mon Sep 17 00:00:00 2001 From: poornimaramesh Date: Fri, 6 Jun 2025 17:43:41 +0300 Subject: [PATCH 60/74] merge changes from tests --- backend/app/experiments/dependencies.py | 10 + backend/app/experiments/routers.py | 13 +- 
backend/app/experiments/sampling_utils.py | 2 - backend/app/messages/models.py | 2 +- backend/app/models.py | 243 +-- backend/app/schemas.py | 180 -- backend/jobs/auto_fail.py | 14 +- backend/jobs/create_notifications.py | 18 +- ...392_fix_messages_foreign_key_constraint.py | 41 + backend/tests/test_auto_fail.py | 284 +--- backend/tests/test_bayes_ab.py | 430 ----- backend/tests/test_cmabs.py | 452 ----- backend/tests/test_experiments.py | 604 +++++++ backend/tests/test_mabs.py | 463 ----- backend/tests/test_messages.py | 13 +- backend/tests/test_notifications_job.py | 83 +- frontend/package-lock.json | 1496 ++++++++++------- frontend/tsconfig.json | 24 +- 18 files changed, 1629 insertions(+), 2743 deletions(-) delete mode 100644 backend/app/schemas.py create mode 100644 backend/migrations/versions/45b9483ee392_fix_messages_foreign_key_constraint.py delete mode 100644 backend/tests/test_bayes_ab.py delete mode 100644 backend/tests/test_cmabs.py create mode 100644 backend/tests/test_experiments.py delete mode 100644 backend/tests/test_mabs.py diff --git a/backend/app/experiments/dependencies.py b/backend/app/experiments/dependencies.py index 249a989..75f6632 100644 --- a/backend/app/experiments/dependencies.py +++ b/backend/app/experiments/dependencies.py @@ -166,6 +166,8 @@ async def update_arm_based_on_outcome( rewards: list[float], contexts: Union[list[list[float]], None], treatments: Union[list[float], None], + observation_type: ObservationType, + asession: AsyncSession, ) -> ArmResponse: """ Update the arm parameters based on the outcome. 
@@ -191,6 +193,14 @@ async def update_arm_based_on_outcome( treatments=treatments, ) + await save_updated_data( + arm=experiment.arms[chosen_arm], + draw=draw, + reward=rewards[0], + observation_type=observation_type, + asession=asession, + ) + return ArmResponse.model_validate(arm) diff --git a/backend/app/experiments/routers.py b/backend/app/experiments/routers.py index 58431ad..feaba03 100644 --- a/backend/app/experiments/routers.py +++ b/backend/app/experiments/routers.py @@ -21,7 +21,6 @@ from .dependencies import ( experiments_db_to_schema, format_rewards_for_arm_update, - save_updated_data, update_arm_based_on_outcome, validate_experiment_and_draw, ) @@ -45,6 +44,7 @@ Experiment, ExperimentSample, ExperimentsEnum, + ObservationType, Outcome, ) @@ -431,17 +431,10 @@ async def update_experiment_arm( rewards=rewards_list, contexts=context_list, treatments=treatments_list, - ) - - observation_type = draw.observation_type - - await save_updated_data( - arm=experiment.arms[chosen_arm_index], - draw=draw, - reward=reward, - observation_type=observation_type, + observation_type=ObservationType.USER, asession=asession, ) + return ArmResponse.model_validate(experiment.arms[chosen_arm_index]) except Exception as e: raise HTTPException( diff --git a/backend/app/experiments/sampling_utils.py b/backend/app/experiments/sampling_utils.py index ac9db72..7ae8eca 100644 --- a/backend/app/experiments/sampling_utils.py +++ b/backend/app/experiments/sampling_utils.py @@ -137,7 +137,6 @@ def _update_arm_laplace( reward_likelihood : The likelihood function of the reward. prior_type : The prior type of the arm. 
""" - print(current_mu.shape, current_covariance.shape, reward.shape, context.shape) def objective(theta: np.ndarray) -> float: """ @@ -258,7 +257,6 @@ def update_arm( + [1.0] ) context = np.zeros((len(rewards), 3)) if not context else np.array(context) - print(rewards, treatments) context[:, 0] = np.array(treatments) context[:, 1] = 1.0 - np.array(treatments) context[:, 2] = 1.0 diff --git a/backend/app/messages/models.py b/backend/app/messages/models.py index 28ec3fb..61b557b 100644 --- a/backend/app/messages/models.py +++ b/backend/app/messages/models.py @@ -102,7 +102,7 @@ class EventMessageDB(MessageDB): nullable=False, ) experiment_id: Mapped[int] = mapped_column( - Integer, ForeignKey("experiments_base.experiment_id"), nullable=False + Integer, ForeignKey("experiments.experiment_id"), nullable=False ) __mapper_args__ = {"polymorphic_identity": "event"} diff --git a/backend/app/models.py b/backend/app/models.py index 77833b5..b8df8a0 100644 --- a/backend/app/models.py +++ b/backend/app/models.py @@ -1,21 +1,6 @@ -import uuid -from datetime import datetime -from typing import TYPE_CHECKING, Sequence +from typing import TYPE_CHECKING -from sqlalchemy import ( - Boolean, - DateTime, - Enum, - Float, - ForeignKey, - Integer, - String, - select, -) -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column - -from .schemas import AutoFailUnitType, EventType, Notifications, ObservationType +from sqlalchemy.orm import DeclarativeBase if TYPE_CHECKING: pass @@ -25,227 +10,3 @@ class Base(DeclarativeBase): """Base class for SQLAlchemy models""" pass - - -class ExperimentBaseDB(Base): - """ - Base model for experiments. 
- """ - - __tablename__ = "experiments_base" - - experiment_id: Mapped[int] = mapped_column( - Integer, primary_key=True, nullable=False - ) - name: Mapped[str] = mapped_column(String(length=150), nullable=False) - description: Mapped[str] = mapped_column(String(length=500), nullable=False) - sticky_assignment: Mapped[bool] = mapped_column( - Boolean, nullable=False, default=False - ) - auto_fail: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) - auto_fail_value: Mapped[int] = mapped_column(Integer, nullable=True) - auto_fail_unit: Mapped[AutoFailUnitType] = mapped_column( - Enum(AutoFailUnitType), nullable=True - ) - - user_id: Mapped[int] = mapped_column( - Integer, ForeignKey("users.user_id"), nullable=False - ) - workspace_id: Mapped[int] = mapped_column( - Integer, ForeignKey("workspace.workspace_id"), nullable=False - ) - is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) - exp_type: Mapped[str] = mapped_column(String(length=50), nullable=False) - prior_type: Mapped[str] = mapped_column(String(length=50), nullable=False) - reward_type: Mapped[str] = mapped_column(String(length=50), nullable=False) - - created_datetime_utc: Mapped[datetime] = mapped_column( - DateTime(timezone=True), nullable=False - ) - n_trials: Mapped[int] = mapped_column(Integer, nullable=False) - last_trial_datetime_utc: Mapped[datetime] = mapped_column( - DateTime(timezone=True), nullable=True - ) - - __mapper_args__ = { - "polymorphic_identity": "experiment", - "polymorphic_on": "exp_type", - } - - def __repr__(self) -> str: - """ - String representation of the model - """ - return f"" - - -class ArmBaseDB(Base): - """ - Base model for arms. 
- """ - - __tablename__ = "arms_base" - - arm_id: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False) - experiment_id: Mapped[int] = mapped_column( - Integer, ForeignKey("experiments_base.experiment_id"), nullable=False - ) - user_id: Mapped[int] = mapped_column( - Integer, ForeignKey("users.user_id"), nullable=False - ) - - name: Mapped[str] = mapped_column(String(length=150), nullable=False) - description: Mapped[str] = mapped_column(String(length=500), nullable=False) - arm_type: Mapped[str] = mapped_column(String(length=50), nullable=False) - n_outcomes: Mapped[int] = mapped_column(Integer, nullable=False, default=0) - - __mapper_args__ = { - "polymorphic_identity": "arm", - "polymorphic_on": "arm_type", - } - - -class DrawsBaseDB(Base): - """ - Base model for draws. - """ - - __tablename__ = "draws_base" - - draw_id: Mapped[str] = mapped_column( - String, primary_key=True, default=lambda x: str(uuid.uuid4()) - ) - - client_id: Mapped[str] = mapped_column(String, nullable=True) - - arm_id: Mapped[int] = mapped_column( - Integer, ForeignKey("arms_base.arm_id"), nullable=False - ) - experiment_id: Mapped[int] = mapped_column( - Integer, ForeignKey("experiments_base.experiment_id"), nullable=False - ) - user_id: Mapped[int] = mapped_column( - Integer, ForeignKey("users.user_id"), nullable=False - ) - - draw_datetime_utc: Mapped[datetime] = mapped_column( - DateTime(timezone=True), - nullable=False, - ) - - observed_datetime_utc: Mapped[datetime] = mapped_column( - DateTime(timezone=True), nullable=True - ) - - observation_type: Mapped[ObservationType] = mapped_column( - Enum(ObservationType), nullable=True - ) - - draw_type: Mapped[str] = mapped_column(String(length=50), nullable=False) - - reward: Mapped[float] = mapped_column(Float, nullable=True) - - __mapper_args__ = { - "polymorphic_identity": "draw", - "polymorphic_on": "draw_type", - } - - -class NotificationsDB(Base): - """ - Model for notifications. 
- Note: if you are updating this, you should also update models in - the background celery job - """ - - __tablename__ = "notifications_db" - - notification_id: Mapped[int] = mapped_column( - Integer, primary_key=True, nullable=False - ) - experiment_id: Mapped[int] = mapped_column( - Integer, ForeignKey("experiments_base.experiment_id"), nullable=False - ) - user_id: Mapped[int] = mapped_column( - Integer, ForeignKey("users.user_id"), nullable=False - ) - notification_type: Mapped[EventType] = mapped_column( - Enum(EventType), nullable=False - ) - notification_value: Mapped[int] = mapped_column(Integer, nullable=False) - is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) - - def to_dict(self) -> dict: - """ - Convert the model to a dictionary - """ - return { - "notification_id": self.notification_id, - "experiment_id": self.experiment_id, - "user_id": self.user_id, - "notification_type": self.notification_type, - "notification_value": self.notification_value, - "is_active": self.is_active, - } - - -async def save_notifications_to_db( - experiment_id: int, - user_id: int, - notifications: Notifications, - asession: AsyncSession, -) -> list[NotificationsDB]: - """ - Save notifications to the database - """ - notification_records = [] - - if notifications.onTrialCompletion: - notification_row = NotificationsDB( - experiment_id=experiment_id, - user_id=user_id, - notification_type=EventType.TRIALS_COMPLETED, - notification_value=notifications.numberOfTrials, - is_active=True, - ) - notification_records.append(notification_row) - - if notifications.onDaysElapsed: - notification_row = NotificationsDB( - experiment_id=experiment_id, - user_id=user_id, - notification_type=EventType.DAYS_ELAPSED, - notification_value=notifications.daysElapsed, - is_active=True, - ) - notification_records.append(notification_row) - - if notifications.onPercentBetter: - notification_row = NotificationsDB( - experiment_id=experiment_id, - user_id=user_id, - 
notification_type=EventType.PERCENTAGE_BETTER, - notification_value=notifications.percentBetterThreshold, - is_active=True, - ) - notification_records.append(notification_row) - - asession.add_all(notification_records) - await asession.commit() - - return notification_records - - -async def get_notifications_from_db( - experiment_id: int, user_id: int, asession: AsyncSession -) -> Sequence[NotificationsDB]: - """ - Get notifications from the database - """ - statement = ( - select(NotificationsDB) - .where(NotificationsDB.experiment_id == experiment_id) - .where(NotificationsDB.user_id == user_id) - ) - - return (await asession.execute(statement)).scalars().all() diff --git a/backend/app/schemas.py b/backend/app/schemas.py deleted file mode 100644 index 783c52e..0000000 --- a/backend/app/schemas.py +++ /dev/null @@ -1,180 +0,0 @@ -from enum import Enum, StrEnum -from typing import Any, Self - -import numpy as np -from pydantic import BaseModel, ConfigDict, model_validator -from pydantic.types import NonNegativeInt - - -class EventType(StrEnum): - """Types of events that can trigger a notification""" - - DAYS_ELAPSED = "days_elapsed" - TRIALS_COMPLETED = "trials_completed" - PERCENTAGE_BETTER = "percentage_better" - - -class ObservationType(StrEnum): - """Types of observations that can be made""" - - USER = "user" # Generated by the user - AUTO = "auto" # Generated by the system - - -class AutoFailUnitType(StrEnum): - """Types of units for auto fail""" - - DAYS = "days" - HOURS = "hours" - - -class Notifications(BaseModel): - """ - Pydantic model for a notifications. - """ - - onTrialCompletion: bool = False - numberOfTrials: NonNegativeInt | None - onDaysElapsed: bool = False - daysElapsed: NonNegativeInt | None - onPercentBetter: bool = False - percentBetterThreshold: NonNegativeInt | None - - @model_validator(mode="after") - def validate_has_assocatiated_value(self) -> Self: - """ - Validate that the required corresponding fields have been set. 
- """ - if self.onTrialCompletion and ( - not self.numberOfTrials or self.numberOfTrials == 0 - ): - raise ValueError( - "numberOfTrials is required when onTrialCompletion is True" - ) - if self.onDaysElapsed and (not self.daysElapsed or self.daysElapsed == 0): - raise ValueError("daysElapsed is required when onDaysElapsed is True") - if self.onPercentBetter and ( - not self.percentBetterThreshold or self.percentBetterThreshold == 0 - ): - raise ValueError( - "percentBetterThreshold is required when onPercentBetter is True" - ) - - return self - - -class NotificationsResponse(BaseModel): - """ - Pydantic model for a response for notifications - """ - - model_config = ConfigDict(from_attributes=True) - - notification_id: int - notification_type: EventType - notification_value: NonNegativeInt - is_active: bool - - -class Outcome(float, Enum): - """ - Enum for the outcome of a trial. - """ - - SUCCESS = 1 - FAILURE = 0 - - -class ArmPriors(StrEnum): - """ - Enum for the prior distribution of the arm. - """ - - BETA = "beta" - NORMAL = "normal" - - def __call__(self, theta: np.ndarray, **kwargs: Any) -> np.ndarray: - """ - Return the log pdf of the input param. - """ - if self == ArmPriors.BETA: - alpha = kwargs.get("alpha", np.ones_like(theta)) - beta = kwargs.get("beta", np.ones_like(theta)) - return (alpha - 1) * np.log(theta) + (beta - 1) * np.log(1 - theta) - - elif self == ArmPriors.NORMAL: - mu = kwargs.get("mu", np.zeros_like(theta)) - covariance = kwargs.get("covariance", np.diag(np.ones_like(theta))) - inv_cov = np.linalg.inv(covariance) - x = theta - mu - return -0.5 * x @ inv_cov @ x - - -class RewardLikelihood(StrEnum): - """ - Enum for the likelihood distribution of the reward. - """ - - BERNOULLI = "binary" - NORMAL = "real-valued" - - def __call__(self, reward: np.ndarray, probs: np.ndarray) -> np.ndarray: - """ - Calculate the log likelihood of the reward. - - Parameters - ---------- - reward : The reward. - probs : The probability of the reward. 
- """ - if self == RewardLikelihood.NORMAL: - return -0.5 * np.sum((reward - probs) ** 2) - elif self == RewardLikelihood.BERNOULLI: - return np.sum(reward * np.log(probs) + (1 - reward) * np.log(1 - probs)) - - -class ContextType(StrEnum): - """ - Enum for the type of context. - """ - - BINARY = "binary" - REAL_VALUED = "real-valued" - - -class ContextLinkFunctions(StrEnum): - """ - Enum for the link function of the arm params and context. - """ - - NONE = "none" - LOGISTIC = "logistic" - - def __call__(self, x: np.ndarray) -> np.ndarray: - """ - Apply the link function to the input param. - - Parameters - ---------- - x : The input param. - """ - if self == ContextLinkFunctions.NONE: - return x - elif self == ContextLinkFunctions.LOGISTIC: - return 1.0 / (1.0 + np.exp(-x)) - - -allowed_combos_mab = [ - (ArmPriors.BETA, RewardLikelihood.BERNOULLI), - (ArmPriors.NORMAL, RewardLikelihood.NORMAL), -] - -allowed_combos_cmab = [ - (ArmPriors.NORMAL, RewardLikelihood.BERNOULLI), - (ArmPriors.NORMAL, RewardLikelihood.NORMAL), -] - -allowed_combos_bayes_ab = [ - (ArmPriors.NORMAL, RewardLikelihood.BERNOULLI), - (ArmPriors.NORMAL, RewardLikelihood.NORMAL), -] diff --git a/backend/jobs/auto_fail.py b/backend/jobs/auto_fail.py index 2323a56..ebf8231 100644 --- a/backend/jobs/auto_fail.py +++ b/backend/jobs/auto_fail.py @@ -78,15 +78,17 @@ async def auto_fail_experiment(asession: AsyncSession) -> int: rewards_list, context_list, treatments_list = ( await format_rewards_for_arm_update( - experiment, draw.arm_id, 0.0, asession + experiment, draw.arm_id, 0.0, draw.context_val, asession ) ) await update_arm_based_on_outcome( - experiment, - draw, - rewards_list, - context_list, - treatments_list, + experiment=experiment, + draw=draw, + rewards=rewards_list, + contexts=context_list, + treatments=treatments_list, + observation_type=ObservationType.AUTO, + asession=asession, ) total_failed += 1 diff --git a/backend/jobs/create_notifications.py b/backend/jobs/create_notifications.py 
index bb50508..f35adee 100644 --- a/backend/jobs/create_notifications.py +++ b/backend/jobs/create_notifications.py @@ -16,9 +16,9 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.database import get_async_session +from app.experiments.models import ExperimentDB, NotificationsDB +from app.experiments.schemas import EventType from app.messages.models import EventMessageDB -from app.models import ExperimentBaseDB, NotificationsDB -from app.schemas import EventType from app.utils import setup_logger logger = setup_logger(log_level=logging.INFO) @@ -34,10 +34,10 @@ async def check_days_elapsed( Check if the number of days elapsed since the experiment was created is greater than or equal to the milestone """ - experiments_stmt = select(ExperimentBaseDB).where( - ExperimentBaseDB.experiment_id == experiment_id + experiments_stmt = select(ExperimentDB).where( + ExperimentDB.experiment_id == experiment_id ) - experiment: ExperimentBaseDB | None = ( + experiment: ExperimentDB | None = ( (await asession.execute(experiments_stmt)).scalars().first() ) @@ -100,12 +100,8 @@ async def check_trials_completed( or equal to the milestone. 
""" # Fetch experiment - stmt = select(ExperimentBaseDB).where( - ExperimentBaseDB.experiment_id == experiment_id - ) - experiment: ExperimentBaseDB | None = ( - (await asession.execute(stmt)).scalars().first() - ) + stmt = select(ExperimentDB).where(ExperimentDB.experiment_id == experiment_id) + experiment: ExperimentDB | None = (await asession.execute(stmt)).scalars().first() if experiment: if experiment.n_trials >= milestone_trials: diff --git a/backend/migrations/versions/45b9483ee392_fix_messages_foreign_key_constraint.py b/backend/migrations/versions/45b9483ee392_fix_messages_foreign_key_constraint.py new file mode 100644 index 0000000..cdfcd4c --- /dev/null +++ b/backend/migrations/versions/45b9483ee392_fix_messages_foreign_key_constraint.py @@ -0,0 +1,41 @@ +"""fix messages foreign key constraint + +Revision ID: 45b9483ee392 +Revises: 6101ba814d91 +Create Date: 2025-06-05 18:10:33.744331 + +""" + +from typing import Sequence, Union + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "45b9483ee392" +down_revision: Union[str, None] = "6101ba814d91" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint( + op.f("event_messages_experiment_id_fkey"), "event_messages", type_="foreignkey" + ) + op.create_foreign_key( + None, "event_messages", "experiments", ["experiment_id"], ["experiment_id"] + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_constraint(None, "event_messages", type_="foreignkey") + op.create_foreign_key( + op.f("event_messages_experiment_id_fkey"), + "event_messages", + "experiments_base", + ["experiment_id"], + ["experiment_id"], + ) + # ### end Alembic commands ### diff --git a/backend/tests/test_auto_fail.py b/backend/tests/test_auto_fail.py index 6a91def..e8ae0b5 100644 --- a/backend/tests/test_auto_fail.py +++ b/backend/tests/test_auto_fail.py @@ -6,14 +6,13 @@ from pytest import FixtureRequest, MonkeyPatch, fixture, mark from sqlalchemy.ext.asyncio import AsyncSession -from backend.app.bayes_ab import models as bayes_ab_models -from backend.app.contextual_mab import models as cmab_models -from backend.app.mab import models as mab_models -from backend.jobs.auto_fail import auto_fail_bayes_ab, auto_fail_cmab, auto_fail_mab +from backend.app.experiments import models +from backend.jobs.auto_fail import auto_fail_experiment -base_mab_payload = { +base_experiment_payload = { "name": "Test AUTO FAIL", "description": "Test AUTO FAIL description", + "exp_type": "mab", "prior_type": "beta", "reward_type": "binary", "auto_fail": True, @@ -41,84 +40,8 @@ "onPercentBetter": False, "percentBetterThreshold": 5, }, -} - -base_cmab_payload = { - "name": "Test", - "description": "Test description", - "prior_type": "normal", - "reward_type": "real-valued", - "auto_fail": True, - "auto_fail_value": 3, - "auto_fail_unit": "hours", - "arms": [ - { - "name": "arm 1", - "description": "arm 1 description", - "mu_init": 0, - "sigma_init": 1, - }, - { - "name": "arm 2", - "description": "arm 2 description", - "mu_init": 0, - "sigma_init": 1, - }, - ], - "contexts": [ - { - "name": "Context 1", - "description": "context 1 description", - "value_type": "binary", - }, - { - "name": "Context 2", - "description": "context 2 description", - "value_type": "real-valued", - }, - ], - "notifications": { - "onTrialCompletion": True, - "numberOfTrials": 2, - "onDaysElapsed": False, - "daysElapsed": 3, - 
"onPercentBetter": False, - "percentBetterThreshold": 5, - }, -} - -base_ab_payload = { - "name": "Test", - "description": "Test description", - "prior_type": "normal", - "reward_type": "real-valued", - "auto_fail": True, - "auto_fail_value": 3, - "auto_fail_unit": "hours", - "arms": [ - { - "name": "arm 1", - "description": "arm 1 description", - "mu_init": 0, - "sigma_init": 1, - "is_treatment_arm": True, - }, - { - "name": "arm 2", - "description": "arm 2 description", - "mu_init": 2, - "sigma_init": 2, - "is_treatment_arm": False, - }, - ], - "notifications": { - "onTrialCompletion": True, - "numberOfTrials": 2, - "onDaysElapsed": False, - "daysElapsed": 3, - "onPercentBetter": False, - "percentBetterThreshold": 5, - }, + "contexts": [], + "clients": [], } @@ -131,201 +54,47 @@ def now(cls, *arg: list) -> datetime: return mydatetime -class TestMABAutoFailJob: - @fixture - def create_mab_with_autofail( - self, - client: TestClient, - admin_token: str, - request: FixtureRequest, - ) -> Generator: - auto_fail_value, auto_fail_unit = request.param - mab_payload = copy.deepcopy(base_mab_payload) - mab_payload["auto_fail_value"] = auto_fail_value - mab_payload["auto_fail_unit"] = auto_fail_unit - - headers = {"Authorization": f"Bearer {admin_token}"} - response = client.post( - "/mab", - json=mab_payload, - headers=headers, - ) - assert response.status_code == 200 - mab = response.json() - yield mab - headers = {"Authorization": f"Bearer {admin_token}"} - client.delete(f"/mab/{mab['experiment_id']}", headers=headers) - - @mark.parametrize( - "create_mab_with_autofail, fail_value, fail_unit, n_observed", - [ - ((12, "hours"), 12, "hours", 2), - ((10, "days"), 10, "days", 3), - ((3, "hours"), 3, "hours", 0), - ((5, "days"), 5, "days", 0), - ], - indirect=["create_mab_with_autofail"], - ) - async def test_auto_fail_job( - self, - client: TestClient, - admin_token: str, - monkeypatch: MonkeyPatch, - create_mab_with_autofail: dict, - fail_value: int, - fail_unit: 
Literal["days", "hours"], - n_observed: int, - asession: AsyncSession, - workspace_api_key: str, - ) -> None: - draws = [] - headers = {"Authorization": f"Bearer {workspace_api_key}"} - for i in range(1, 15): - monkeypatch.setattr( - mab_models, - "datetime", - fake_datetime( - days=i if fail_unit == "days" else 0, - hours=i if fail_unit == "hours" else 0, - ), - ) - response = client.get( - f"/mab/{create_mab_with_autofail['experiment_id']}/draw", - headers=headers, - ) - assert response.status_code == 200 - draws.append(response.json()["draw_id"]) - - if i >= (15 - n_observed): - response = client.put( - f"/mab/{create_mab_with_autofail['experiment_id']}/{draws[-1]}/1", - headers=headers, - ) - assert response.status_code == 200 - - n_failed = await auto_fail_mab(asession=asession) - - assert n_failed == (15 - fail_value - n_observed) - - -class TestBayesABAutoFailJob: - @fixture - def create_bayes_ab_with_autofail( - self, - client: TestClient, - admin_token: str, - request: FixtureRequest, - ) -> Generator: - auto_fail_value, auto_fail_unit = request.param - ab_payload = copy.deepcopy(base_ab_payload) - ab_payload["auto_fail_value"] = auto_fail_value - ab_payload["auto_fail_unit"] = auto_fail_unit - - headers = {"Authorization": f"Bearer {admin_token}"} - response = client.post( - "/bayes_ab", - json=ab_payload, - headers=headers, - ) - assert response.status_code == 200 - ab = response.json() - yield ab - headers = {"Authorization": f"Bearer {admin_token}"} - client.delete(f"/bayes_ab/{ab['experiment_id']}", headers=headers) - - @mark.parametrize( - "create_bayes_ab_with_autofail, fail_value, fail_unit, n_observed", - [ - ((12, "hours"), 12, "hours", 2), - ((10, "days"), 10, "days", 3), - ((3, "hours"), 3, "hours", 0), - ((5, "days"), 5, "days", 0), - ], - indirect=["create_bayes_ab_with_autofail"], - ) - async def test_auto_fail_job( - self, - client: TestClient, - admin_token: str, - monkeypatch: MonkeyPatch, - create_bayes_ab_with_autofail: dict, - 
fail_value: int, - fail_unit: Literal["days", "hours"], - n_observed: int, - asession: AsyncSession, - workspace_api_key: str, - ) -> None: - draws = [] - headers = {"Authorization": f"Bearer {workspace_api_key}"} - for i in range(1, 15): - monkeypatch.setattr( - bayes_ab_models, - "datetime", - fake_datetime( - days=i if fail_unit == "days" else 0, - hours=i if fail_unit == "hours" else 0, - ), - ) - response = client.get( - f"/bayes_ab/{create_bayes_ab_with_autofail['experiment_id']}/draw", - headers=headers, - ) - assert response.status_code == 200 - draws.append(response.json()["draw_id"]) - - if i >= (15 - n_observed): - response = client.put( - f"/bayes_ab/{create_bayes_ab_with_autofail['experiment_id']}/{draws[-1]}/1", - headers=headers, - ) - assert response.status_code == 200 - - n_failed = await auto_fail_bayes_ab(asession=asession) - - assert n_failed == (15 - fail_value - n_observed) - - -class TestCMABAutoFailJob: +class TestExperimentAutoFailJob: @fixture - def create_cmab_with_autofail( + def create_experiment_with_autofail( self, client: TestClient, admin_token: str, request: FixtureRequest, ) -> Generator: auto_fail_value, auto_fail_unit = request.param - cmab_payload = copy.deepcopy(base_cmab_payload) - cmab_payload["auto_fail_value"] = auto_fail_value - cmab_payload["auto_fail_unit"] = auto_fail_unit + experiment_payload = copy.deepcopy(base_experiment_payload) + experiment_payload["auto_fail_value"] = auto_fail_value + experiment_payload["auto_fail_unit"] = auto_fail_unit headers = {"Authorization": f"Bearer {admin_token}"} response = client.post( - "/contextual_mab", - json=cmab_payload, + "/experiment", + json=experiment_payload, headers=headers, ) assert response.status_code == 200 - cmab = response.json() - yield cmab + experiment = response.json() + yield experiment headers = {"Authorization": f"Bearer {admin_token}"} - client.delete(f"/contextual_mab/{cmab['experiment_id']}", headers=headers) + 
client.delete(f"/experiment/id/{experiment['experiment_id']}", headers=headers) @mark.parametrize( - "create_cmab_with_autofail, fail_value, fail_unit, n_observed", + "create_experiment_with_autofail, fail_value, fail_unit, n_observed", [ ((12, "hours"), 12, "hours", 2), ((10, "days"), 10, "days", 3), ((3, "hours"), 3, "hours", 0), ((5, "days"), 5, "days", 0), ], - indirect=["create_cmab_with_autofail"], + indirect=["create_experiment_with_autofail"], ) async def test_auto_fail_job( self, client: TestClient, admin_token: str, monkeypatch: MonkeyPatch, - create_cmab_with_autofail: dict, + create_experiment_with_autofail: dict, fail_value: int, fail_unit: Literal["days", "hours"], n_observed: int, @@ -336,19 +105,15 @@ async def test_auto_fail_job( headers = {"Authorization": f"Bearer {workspace_api_key}"} for i in range(1, 15): monkeypatch.setattr( - cmab_models, + models, "datetime", fake_datetime( days=i if fail_unit == "days" else 0, hours=i if fail_unit == "hours" else 0, ), ) - response = client.post( - f"/contextual_mab/{create_cmab_with_autofail['experiment_id']}/draw", - json=[ - {"context_id": 1, "context_value": 0}, - {"context_id": 2, "context_value": 0}, - ], + response = client.put( + f"/experiment/{create_experiment_with_autofail['experiment_id']}/draw", headers=headers, ) assert response.status_code == 200 @@ -356,11 +121,12 @@ async def test_auto_fail_job( if i >= (15 - n_observed): response = client.put( - f"/contextual_mab/{create_cmab_with_autofail['experiment_id']}/{draws[-1]}/1", + f"/experiment/{create_experiment_with_autofail['experiment_id']}/{draws[-1]}/1", headers=headers, ) + print(response.json()) assert response.status_code == 200 - n_failed = await auto_fail_cmab(asession=asession) + n_failed = await auto_fail_experiment(asession=asession) assert n_failed == (15 - fail_value - n_observed) diff --git a/backend/tests/test_bayes_ab.py b/backend/tests/test_bayes_ab.py deleted file mode 100644 index 6b75b6f..0000000 --- 
a/backend/tests/test_bayes_ab.py +++ /dev/null @@ -1,430 +0,0 @@ -import copy -import os -from typing import Generator - -import numpy as np -from fastapi.testclient import TestClient -from pytest import FixtureRequest, fixture, mark -from sqlalchemy.orm import Session - -from backend.app.bayes_ab.models import BayesianABArmDB, BayesianABDB -from backend.app.models import NotificationsDB - -base_normal_payload = { - "name": "Test", - "description": "Test description", - "prior_type": "normal", - "reward_type": "real-valued", - "sticky_assignment": False, - "arms": [ - { - "name": "arm 1", - "description": "arm 1 description", - "mu_init": 0, - "sigma_init": 1, - "is_treatment_arm": True, - }, - { - "name": "arm 2", - "description": "arm 2 description", - "mu_init": 2, - "sigma_init": 2, - "is_treatment_arm": False, - }, - ], - "notifications": { - "onTrialCompletion": True, - "numberOfTrials": 2, - "onDaysElapsed": False, - "daysElapsed": 3, - "onPercentBetter": False, - "percentBetterThreshold": 5, - }, -} - -base_binary_normal_payload = base_normal_payload.copy() -base_binary_normal_payload["reward_type"] = "binary" - - -@fixture -def clean_bayes_ab(db_session: Session) -> Generator: - """ - Fixture to clean the database before each test. - """ - yield - db_session.query(NotificationsDB).delete() - db_session.query(BayesianABArmDB).delete() - db_session.query(BayesianABDB).delete() - - db_session.commit() - - -@fixture -def admin_token(client: TestClient) -> str: - """Get a token for the admin user""" - response = client.post( - "/login", - data={ - "username": os.environ.get("ADMIN_USERNAME", "admin@idinsight.org"), - "password": os.environ.get("ADMIN_PASSWORD", "12345"), - }, - ) - assert response.status_code == 200, f"Login failed: {response.json()}" - token = response.json()["access_token"] - return token - - -class TestBayesAB: - """ - Test class for Bayesian A/B testing. 
- """ - - @fixture - def create_bayes_ab_payload(self, request: FixtureRequest) -> dict: - """ - Fixture to create a payload for the Bayesian A/B test. - """ - payload_normal: dict = copy.deepcopy(base_normal_payload) - payload_normal["arms"] = list(payload_normal["arms"]) - - payload_binary_normal: dict = copy.deepcopy(base_binary_normal_payload) - payload_binary_normal["arms"] = list(payload_binary_normal["arms"]) - - if request.param == "base_normal": - return payload_normal - if request.param == "base_binary_normal": - return payload_binary_normal - if request.param == "one_arm": - payload_normal["arms"].pop() - return payload_normal - if request.param == "no_notifications": - payload_normal["notifications"]["onTrialCompletion"] = False - return payload_normal - if request.param == "invalid_prior": - payload_normal["prior_type"] = "beta" - return payload_normal - if request.param == "invalid_sigma": - payload_normal["arms"][0]["sigma_init"] = 0 - return payload_normal - if request.param == "invalid_params": - payload_normal["arms"][0].pop("mu_init") - return payload_normal - if request.param == "two_treatment_arms": - payload_normal["arms"][0]["is_treatment_arm"] = True - payload_normal["arms"][1]["is_treatment_arm"] = True - return payload_normal - if request.param == "with_sticky_assignment": - payload_normal["sticky_assignment"] = True - return payload_normal - else: - raise ValueError("Invalid parameter") - - @fixture - def create_bayes_abs( - self, - client: TestClient, - admin_token: str, - create_bayes_ab_payload: dict, - request: FixtureRequest, - ) -> Generator: - bayes_abs = [] - n_bayes_abs = request.param if hasattr(request, "param") else 1 - for _ in range(n_bayes_abs): - response = client.post( - "/bayes_ab", - json=create_bayes_ab_payload, - headers={"Authorization": f"Bearer {admin_token}"}, - ) - bayes_abs.append(response.json()) - yield bayes_abs - for bayes_ab in bayes_abs: - client.delete( - f"/bayes_ab/{bayes_ab['experiment_id']}", - 
headers={"Authorization": f"Bearer {admin_token}"}, - ) - - @mark.parametrize( - "create_bayes_ab_payload, expected_response", - [ - ("base_normal", 200), - ("base_binary_normal", 200), - ("one_arm", 422), - ("no_notifications", 200), - ("invalid_prior", 422), - ("invalid_sigma", 422), - ("invalid_params", 200), - ("two_treatment_arms", 422), - ], - indirect=["create_bayes_ab_payload"], - ) - def test_create_bayes_ab( - self, - create_bayes_ab_payload: dict, - client: TestClient, - expected_response: int, - admin_token: str, - clean_bayes_ab: None, - ) -> None: - """ - Test the creation of a Bayesian A/B test. - """ - response = client.post( - "/bayes_ab", - json=create_bayes_ab_payload, - headers={"Authorization": f"Bearer {admin_token}"}, - ) - - assert response.status_code == expected_response - - @mark.parametrize( - "create_bayes_abs, n_expected, create_bayes_ab_payload", - [(1, 1, "base_normal"), (2, 2, "base_normal"), (5, 5, "base_normal")], - indirect=["create_bayes_abs", "create_bayes_ab_payload"], - ) - def test_get_bayes_abs( - self, - client: TestClient, - n_expected: int, - admin_token: str, - create_bayes_abs: list, - create_bayes_ab_payload: dict, - ) -> None: - """ - Test the retrieval of Bayesian A/B tests. 
- """ - response = client.get( - "/bayes_ab", headers={"Authorization": f"Bearer {admin_token}"} - ) - - assert response.status_code == 200 - assert len(response.json()) == n_expected - - @mark.parametrize( - "create_bayes_abs, expected_response, create_bayes_ab_payload", - [(1, 200, "base_normal"), (2, 200, "base_normal"), (5, 200, "base_normal")], - indirect=["create_bayes_abs", "create_bayes_ab_payload"], - ) - def test_draw_arm( - self, - client: TestClient, - create_bayes_abs: list, - create_bayes_ab_payload: dict, - expected_response: int, - workspace_api_key: str, - ) -> None: - id = create_bayes_abs[0]["experiment_id"] - response = client.get( - f"/bayes_ab/{id}/draw", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - assert response.status_code == expected_response - - @mark.parametrize( - "create_bayes_ab_payload, client_id, expected_response", - [ - ("with_sticky_assignment", None, 400), - ("with_sticky_assignment", "test_client_id", 200), - ], - indirect=["create_bayes_ab_payload"], - ) - def test_draw_arm_with_client_id( - self, - client: TestClient, - create_bayes_abs: list, - create_bayes_ab_payload: dict, - client_id: str | None, - expected_response: int, - workspace_api_key: str, - ) -> None: - id = create_bayes_abs[0]["experiment_id"] - response = client.get( - f"/bayes_ab/{id}/draw{'?client_id=' + client_id if client_id else ''}", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - assert response.status_code == expected_response - - @mark.parametrize( - "create_bayes_ab_payload", ["with_sticky_assignment"], indirect=True - ) - def test_draw_arm_with_sticky_assignment( - self, - client: TestClient, - create_bayes_abs: list, - create_bayes_ab_payload: dict, - workspace_api_key: str, - ) -> None: - id = create_bayes_abs[0]["experiment_id"] - arm_ids = [] - for _ in range(10): - response = client.get( - f"/bayes_ab/{id}/draw?client_id=123", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - 
arm_ids.append(response.json()["arm"]["arm_id"]) - assert np.unique(arm_ids).size == 1 - - @mark.parametrize("create_bayes_ab_payload", ["base_normal"], indirect=True) - def test_update_observation( - self, - client: TestClient, - create_bayes_abs: list, - create_bayes_ab_payload: dict, - workspace_api_key: str, - ) -> None: - id = create_bayes_abs[0]["experiment_id"] - - # First, get a draw - response = client.get( - f"/bayes_ab/{id}/draw", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - assert response.status_code == 200 - draw_id = response.json()["draw_id"] - - # Then update with an observation - response = client.put( - f"/bayes_ab/{id}/{draw_id}/0.5", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - assert response.status_code == 200 - - # Test that we can't update the same draw twice - response = client.put( - f"/bayes_ab/{id}/{draw_id}/0.5", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - assert response.status_code == 400 - - @mark.parametrize("create_bayes_ab_payload", ["base_normal"], indirect=True) - def test_get_outcomes( - self, - client: TestClient, - create_bayes_abs: list, - create_bayes_ab_payload: dict, - workspace_api_key: str, - ) -> None: - id = create_bayes_abs[0]["experiment_id"] - - # First, get a draw - response = client.get( - f"/bayes_ab/{id}/draw", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - assert response.status_code == 200 - draw_id = response.json()["draw_id"] - - # Then update with an observation - response = client.put( - f"/bayes_ab/{id}/{draw_id}/0.5", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - assert response.status_code == 200 - - # Get outcomes - response = client.get( - f"/bayes_ab/{id}/outcomes", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - assert response.status_code == 200 - assert len(response.json()) == 1 - - @mark.parametrize("create_bayes_ab_payload", ["base_normal"], indirect=True) - def 
test_get_arms( - self, - client: TestClient, - create_bayes_abs: list, - create_bayes_ab_payload: dict, - workspace_api_key: str, - ) -> None: - id = create_bayes_abs[0]["experiment_id"] - - # First, get a draw - response = client.get( - f"/bayes_ab/{id}/draw", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - assert response.status_code == 200 - draw_id = response.json()["draw_id"] - - # Then update with an observation - response = client.put( - f"/bayes_ab/{id}/{draw_id}/0.5", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - assert response.status_code == 200 - - # Get arms - response = client.get( - f"/bayes_ab/{id}/arms", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - assert response.status_code == 200 - assert len(response.json()) == 2 - - -class TestNotifications: - @fixture() - def create_bayes_ab_payload(self, request: FixtureRequest) -> dict: - payload: dict = copy.deepcopy(base_normal_payload) - payload["arms"] = list(payload["arms"]) - - match request.param: - case "base": - pass - case "daysElapsed_only": - payload["notifications"]["onTrialCompletion"] = False - payload["notifications"]["onDaysElapsed"] = True - case "trialCompletion_only": - payload["notifications"]["onTrialCompletion"] = True - case "percentBetter_only": - payload["notifications"]["onTrialCompletion"] = False - payload["notifications"]["onPercentBetter"] = True - case "all_notifications": - payload["notifications"]["onDaysElapsed"] = True - payload["notifications"]["onPercentBetter"] = True - case "no_notifications": - payload["notifications"]["onTrialCompletion"] = False - case "daysElapsed_missing": - payload["notifications"]["daysElapsed"] = 0 - payload["notifications"]["onDaysElapsed"] = True - case "trialCompletion_missing": - payload["notifications"]["numberOfTrials"] = 0 - payload["notifications"]["onTrialCompletion"] = True - case "percentBetter_missing": - payload["notifications"]["percentBetterThreshold"] = 0 - 
payload["notifications"]["onPercentBetter"] = True - case _: - raise ValueError("Invalid parameter") - - return payload - - @mark.parametrize( - "create_bayes_ab_payload, expected_response", - [ - ("base", 200), - ("daysElapsed_only", 200), - ("trialCompletion_only", 200), - ("percentBetter_only", 200), - ("all_notifications", 200), - ("no_notifications", 200), - ("daysElapsed_missing", 422), - ("trialCompletion_missing", 422), - ("percentBetter_missing", 422), - ], - indirect=["create_bayes_ab_payload"], - ) - def test_notifications( - self, - client: TestClient, - admin_token: str, - create_bayes_ab_payload: dict, - expected_response: int, - clean_bayes_ab: None, - ) -> None: - response = client.post( - "/bayes_ab", - json=create_bayes_ab_payload, - headers={"Authorization": f"Bearer {admin_token}"}, - ) - - assert response.status_code == expected_response diff --git a/backend/tests/test_cmabs.py b/backend/tests/test_cmabs.py deleted file mode 100644 index d9b6ed0..0000000 --- a/backend/tests/test_cmabs.py +++ /dev/null @@ -1,452 +0,0 @@ -import copy -import os -from typing import Generator - -import numpy as np -from fastapi.testclient import TestClient -from pytest import FixtureRequest, fixture, mark -from sqlalchemy.orm import Session - -from backend.app.contextual_mab.models import ( - ContextDB, - ContextualArmDB, - ContextualBanditDB, -) -from backend.app.models import NotificationsDB - -base_normal_payload = { - "name": "Test", - "description": "Test description", - "prior_type": "normal", - "reward_type": "real-valued", - "sticky_assignment": False, - "arms": [ - { - "name": "arm 1", - "description": "arm 1 description", - "mu_init": 0, - "sigma_init": 1, - }, - { - "name": "arm 2", - "description": "arm 2 description", - "mu_init": 0, - "sigma_init": 1, - }, - ], - "contexts": [ - { - "name": "Context 1", - "description": "context 1 description", - "value_type": "binary", - }, - { - "name": "Context 2", - "description": "context 2 description", - 
"value_type": "real-valued", - }, - ], - "notifications": { - "onTrialCompletion": True, - "numberOfTrials": 2, - "onDaysElapsed": False, - "daysElapsed": 3, - "onPercentBetter": False, - "percentBetterThreshold": 5, - }, -} - -base_binary_normal_payload = base_normal_payload.copy() -base_binary_normal_payload["reward_type"] = "binary" - - -@fixture -def admin_token(client: TestClient) -> str: - response = client.post( - "/login", - data={ - "username": os.environ.get("ADMIN_USERNAME", ""), - "password": os.environ.get("ADMIN_PASSWORD", ""), - }, - ) - token = response.json()["access_token"] - return token - - -@fixture -def clean_cmabs(db_session: Session) -> Generator: - yield - db_session.query(NotificationsDB).delete() - db_session.query(ContextualArmDB).delete() - db_session.query(ContextDB).delete() - db_session.query(ContextualBanditDB).delete() - db_session.commit() - - -class TestCMab: - @fixture - def create_cmab_payload(self, request: FixtureRequest) -> dict: - payload_normal: dict = copy.deepcopy(base_normal_payload) - payload_normal["arms"] = list(payload_normal["arms"]) - payload_normal["contexts"] = list(payload_normal["contexts"]) - - payload_binary_normal: dict = copy.deepcopy(base_binary_normal_payload) - payload_binary_normal["arms"] = list(payload_binary_normal["arms"]) - payload_binary_normal["contexts"] = list(payload_binary_normal["contexts"]) - - if request.param == "base_normal": - return payload_normal - if request.param == "base_binary_normal": - return payload_binary_normal - if request.param == "one_arm": - payload_normal["arms"].pop() - return payload_normal - if request.param == "no_notifications": - payload_normal["notifications"]["onTrialCompletion"] = False - return payload_normal - if request.param == "invalid_prior": - payload_normal["prior_type"] = "beta" - return payload_normal - if request.param == "invalid_reward": - payload_normal["reward_type"] = "invalid" - return payload_normal - if request.param == "invalid_sigma": - 
payload_normal["arms"][0]["sigma_init"] = 0 - return payload_normal - if request.param == "with_sticky_assignment": - payload_normal["sticky_assignment"] = True - return payload_normal - - else: - raise ValueError("Invalid parameter") - - @mark.parametrize( - "create_cmab_payload, expected_response", - [ - ("base_normal", 200), - ("base_binary_normal", 200), - ("one_arm", 422), - ("no_notifications", 200), - ("invalid_prior", 422), - ("invalid_reward", 422), - ("invalid_sigma", 422), - ], - indirect=["create_cmab_payload"], - ) - def test_create_cmab( - self, - create_cmab_payload: dict, - client: TestClient, - expected_response: int, - admin_token: str, - clean_cmabs: None, - ) -> None: - response = client.post( - "/contextual_mab", - json=create_cmab_payload, - headers={"Authorization": f"Bearer {admin_token}"}, - ) - - assert response.status_code == expected_response - - @fixture - def create_cmabs( - self, - client: TestClient, - admin_token: str, - request: FixtureRequest, - create_cmab_payload: dict, - ) -> Generator: - cmabs = [] - n_cmabs = request.param if hasattr(request, "param") else 1 - for _ in range(n_cmabs): - response = client.post( - "/contextual_mab", - json=create_cmab_payload, - headers={"Authorization": f"Bearer {admin_token}"}, - ) - cmabs.append(response.json()) - yield cmabs - for cmab in cmabs: - client.delete( - f"/contextual_mab/{cmab['experiment_id']}", - headers={"Authorization": f"Bearer {admin_token}"}, - ) - - @mark.parametrize( - "create_cmabs, n_expected, create_cmab_payload", - [(0, 0, "base_normal"), (2, 2, "base_normal"), (5, 5, "base_normal")], - indirect=["create_cmabs", "create_cmab_payload"], - ) - def test_get_all_cmabs( - self, - client: TestClient, - admin_token: str, - n_expected: int, - create_cmab_payload: dict, - create_cmabs: list, - ) -> None: - response = client.get( - "/contextual_mab", headers={"Authorization": f"Bearer {admin_token}"} - ) - assert response.status_code == 200 - assert len(response.json()) == 
n_expected - - @mark.parametrize( - "create_cmabs, expected_response, create_cmab_payload", - [(0, 404, "base_normal"), (2, 200, "base_normal")], - indirect=["create_cmabs", "create_cmab_payload"], - ) - def test_get_cmab( - self, - client: TestClient, - admin_token: str, - create_cmab_payload: dict, - create_cmabs: list, - expected_response: int, - ) -> None: - id = create_cmabs[0]["experiment_id"] if create_cmabs else 999 - - response = client.get( - f"/contextual_mab/{id}", headers={"Authorization": f"Bearer {admin_token}"} - ) - assert response.status_code == expected_response - - @mark.parametrize("create_cmab_payload", ["base_normal"], indirect=True) - def test_draw_arm_draw_id_provided( - self, - client: TestClient, - create_cmabs: list, - create_cmab_payload: dict, - workspace_api_key: str, - ) -> None: - id = create_cmabs[0]["experiment_id"] - response = client.post( - f"/contextual_mab/{id}/draw", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - params={"draw_id": "test_draw_id"}, - json=[ - {"context_id": 1, "context_value": 0}, - {"context_id": 2, "context_value": 0.5}, - ], - ) - assert response.status_code == 200 - assert response.json()["draw_id"] == "test_draw_id" - - @mark.parametrize("create_cmab_payload", ["base_normal"], indirect=True) - def test_draw_arm_no_draw_id_provided( - self, - client: TestClient, - create_cmabs: list, - create_cmab_payload: dict, - workspace_api_key: str, - ) -> None: - id = create_cmabs[0]["experiment_id"] - response = client.post( - f"/contextual_mab/{id}/draw", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - json=[ - {"context_id": 1, "context_value": 0}, - {"context_id": 2, "context_value": 0.5}, - ], - ) - assert response.status_code == 200 - assert len(response.json()["draw_id"]) == 36 - - @mark.parametrize( - "create_cmab_payload, client_id, expected_response", - [ - ("with_sticky_assignment", None, 400), - ("with_sticky_assignment", "test_client_id", 200), - ], - 
indirect=["create_cmab_payload"], - ) - def test_draw_arm_sticky_assignment_client_id_provided( - self, - client: TestClient, - create_cmabs: list, - create_cmab_payload: dict, - client_id: str | None, - expected_response: int, - workspace_api_key: str, - ) -> None: - id = create_cmabs[0]["experiment_id"] - url = f"/contextual_mab/{id}/draw" - if client_id: - url += f"?client_id={client_id}" - - response = client.post( - url, - headers={"Authorization": f"Bearer {workspace_api_key}"}, - json=[ - {"context_id": 1, "context_value": 0}, - {"context_id": 2, "context_value": 0.5}, - ], - ) - assert response.status_code == expected_response - - @mark.parametrize("create_cmab_payload", ["with_sticky_assignment"], indirect=True) - def test_draw_arm_with_sticky_assignment( - self, - client: TestClient, - create_cmabs: list, - create_cmab_payload: dict, - workspace_api_key: str, - ) -> None: - id = create_cmabs[0]["experiment_id"] - arm_ids = [] - - for _ in range(10): - response = client.post( - f"/contextual_mab/{id}/draw?client_id=123", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - json=[ - {"context_id": 1, "context_value": 0}, - {"context_id": 2, "context_value": 1}, - ], - ) - arm_ids.append(response.json()["arm"]["arm_id"]) - - assert np.unique(arm_ids).size == 1 - - @mark.parametrize("create_cmab_payload", ["base_normal"], indirect=True) - def test_one_outcome_per_draw( - self, - client: TestClient, - create_cmabs: list, - create_cmab_payload: dict, - workspace_api_key: str, - ) -> None: - id = create_cmabs[0]["experiment_id"] - response = client.post( - f"/contextual_mab/{id}/draw", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - json=[ - {"context_id": 1, "context_value": 0}, - {"context_id": 2, "context_value": 0.5}, - ], - ) - assert response.status_code == 200 - draw_id = response.json()["draw_id"] - - response = client.put( - f"/contextual_mab/{id}/{draw_id}/1", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - - 
assert response.status_code == 200 - - response = client.put( - f"/contextual_mab/{id}/{draw_id}/1", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - - assert response.status_code == 400 - - @mark.parametrize( - "n_draws, create_cmab_payload", - [(0, "base_normal"), (1, "base_normal"), (5, "base_normal")], - indirect=["create_cmab_payload"], - ) - def test_get_outcomes( - self, - client: TestClient, - create_cmabs: list, - n_draws: int, - create_cmab_payload: dict, - workspace_api_key: str, - ) -> None: - id = create_cmabs[0]["experiment_id"] - - for _ in range(n_draws): - response = client.post( - f"/contextual_mab/{id}/draw", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - json=[ - {"context_id": 1, "context_value": 0}, - {"context_id": 2, "context_value": 0.5}, - ], - ) - assert response.status_code == 200 - draw_id = response.json()["draw_id"] - response = client.put( - f"/contextual_mab/{id}/{draw_id}/1", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - - response = client.get( - f"/contextual_mab/{id}/outcomes", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - - assert response.status_code == 200 - assert len(response.json()) == n_draws - - -class TestNotifications: - @fixture() - def create_cmab_payload(self, request: FixtureRequest) -> dict: - payload: dict = copy.deepcopy(base_normal_payload) - payload["arms"] = list(payload["arms"]) - payload["contexts"] = list(payload["contexts"]) - - match request.param: - case "base": - pass - case "daysElapsed_only": - payload["notifications"]["onTrialCompletion"] = False - payload["notifications"]["onDaysElapsed"] = True - case "trialCompletion_only": - payload["notifications"]["onTrialCompletion"] = True - case "percentBetter_only": - payload["notifications"]["onTrialCompletion"] = False - payload["notifications"]["onPercentBetter"] = True - case "all_notifications": - payload["notifications"]["onDaysElapsed"] = True - 
payload["notifications"]["onPercentBetter"] = True - case "no_notifications": - payload["notifications"]["onTrialCompletion"] = False - case "daysElapsed_missing": - payload["notifications"]["daysElapsed"] = 0 - payload["notifications"]["onDaysElapsed"] = True - case "trialCompletion_missing": - payload["notifications"]["numberOfTrials"] = 0 - payload["notifications"]["onTrialCompletion"] = True - case "percentBetter_missing": - payload["notifications"]["percentBetterThreshold"] = 0 - payload["notifications"]["onPercentBetter"] = True - case _: - raise ValueError("Invalid parameter") - - return payload - - @mark.parametrize( - "create_cmab_payload, expected_response", - [ - ("base", 200), - ("daysElapsed_only", 200), - ("trialCompletion_only", 200), - ("percentBetter_only", 200), - ("all_notifications", 200), - ("no_notifications", 200), - ("daysElapsed_missing", 422), - ("trialCompletion_missing", 422), - ("percentBetter_missing", 422), - ], - indirect=["create_cmab_payload"], - ) - def test_notifications( - self, - client: TestClient, - admin_token: str, - create_cmab_payload: dict, - expected_response: int, - clean_cmabs: None, - ) -> None: - response = client.post( - "/contextual_mab", - json=create_cmab_payload, - headers={"Authorization": f"Bearer {admin_token}"}, - ) - - assert response.status_code == expected_response diff --git a/backend/tests/test_experiments.py b/backend/tests/test_experiments.py new file mode 100644 index 0000000..f2bdd7c --- /dev/null +++ b/backend/tests/test_experiments.py @@ -0,0 +1,604 @@ +import copy +import os +from typing import Generator + +from fastapi.testclient import TestClient +from pytest import FixtureRequest, fixture, mark +from sqlalchemy.orm import Session + +from backend.app.experiments.models import ( + ArmDB, + ContextDB, + ExperimentDB, + NotificationsDB, +) + +mab_beta_binom_payload = { + "name": "Test", + "description": "Test description.", + "exp_type": "mab", + "prior_type": "beta", + "reward_type": "binary", + 
"arms": [ + { + "name": "arm 1", + "description": "arm 1 description.", + "alpha_init": 5, + "beta_init": 1, + "is_treatment_arm": True, + }, + { + "name": "arm 2", + "description": "arm 2 description.", + "alpha_init": 1, + "beta_init": 4, + "is_treatment_arm": False, + }, + ], + "notifications": { + "onTrialCompletion": True, + "numberOfTrials": 2, + "onDaysElapsed": False, + "daysElapsed": 3, + "onPercentBetter": False, + "percentBetterThreshold": 5, + }, + "contexts": [], + "clients": [], +} + + +@fixture +def admin_token(client: TestClient) -> str: + response = client.post( + "/login", + data={ + "username": os.environ.get("ADMIN_USERNAME", ""), + "password": os.environ.get("ADMIN_PASSWORD", ""), + }, + ) + token = response.json()["access_token"] + return token + + +@fixture +def clean_experiments(db_session: Session) -> Generator: + yield + db_session.query(NotificationsDB).delete() + db_session.query(ContextDB).delete() + db_session.query(ArmDB).delete() + db_session.query(ExperimentDB).delete() + db_session.commit() + + +def _get_experiment_payload(input: str) -> dict: + """Helper function to get the experiment payload based on input.""" + payload_mab_beta_binom: dict = copy.deepcopy(mab_beta_binom_payload) + payload_mab_beta_binom["arms"] = list(payload_mab_beta_binom["arms"]) + + payload_mab_normal: dict = copy.deepcopy(mab_beta_binom_payload) + payload_mab_normal["prior_type"] = "normal" + payload_mab_normal["reward_type"] = "real-valued" + payload_mab_normal["arms"] = [ + { + "name": "arm 1", + "description": "arm 1 description", + "mu_init": 2, + "sigma_init": 3, + "is_treatment_arm": True, + }, + { + "name": "arm 2", + "description": "arm 2 description", + "mu_init": 3, + "sigma_init": 7, + "is_treatment_arm": True, + }, + ] + + match input: + case "base_beta_binom": + return payload_mab_beta_binom + case "base_normal": + return payload_mab_normal + case "one_arm": + payload_mab_beta_binom["arms"].pop() + return payload_mab_beta_binom + case 
"no_notifications": + payload_mab_beta_binom["notifications"]["onTrialCompletion"] = False + return payload_mab_beta_binom + case "invalid_prior": + payload_mab_beta_binom["prior_type"] = "invalid" + return payload_mab_beta_binom + case "invalid_reward": + payload_mab_beta_binom["reward_type"] = "invalid" + return payload_mab_beta_binom + case "invalid_alpha": + payload_mab_beta_binom["arms"][0]["alpha_init"] = -1 + return payload_mab_beta_binom + case "invalid_beta": + payload_mab_beta_binom["arms"][0]["beta_init"] = -1 + return payload_mab_beta_binom + case "invalid_combo": + payload_mab_beta_binom["reward_type"] = "real-valued" + return payload_mab_beta_binom + case "incorrect_params": + payload_mab_beta_binom["arms"][0].pop("alpha_init") + return payload_mab_beta_binom + case "invalid_sigma": + payload_mab_normal["arms"][0]["sigma_init"] = 0.0 + return payload_mab_normal + case "invalid_context_input": + payload_mab_beta_binom["contexts"] = [ + { + "name": "context 1", + "description": "context 1 description", + "value_type": "binary", + } + ] + return payload_mab_beta_binom + case "bayes_ab_normal_binom": + payload_mab_normal["exp_type"] = "bayes_ab" + payload_mab_normal["reward_type"] = "real-valued" + payload_mab_normal["arms"][1]["is_treatment_arm"] = False + return payload_mab_normal + case "bayes_ab_invalid_prior": + payload_mab_beta_binom["exp_type"] = "bayes_ab" + payload_mab_beta_binom["arms"][1]["is_treatment_arm"] = False + return payload_mab_beta_binom + case "bayes_ab_invalid_arm": + payload_mab_normal["exp_type"] = "bayes_ab" + payload_mab_normal["reward_type"] = "real-valued" + return payload_mab_normal + case "bayes_ab_invalid_context": + payload_mab_normal["exp_type"] = "bayes_ab" + payload_mab_normal["reward_type"] = "real-valued" + payload_mab_normal["arms"][1]["is_treatment_arm"] = False + payload_mab_normal["contexts"] = [ + { + "name": "context 1", + "description": "context 1 description", + "value_type": "binary", + } + ] + return 
payload_mab_normal + case "cmab_normal": + payload_mab_normal["exp_type"] = "cmab" + payload_mab_normal["contexts"] = [ + { + "name": "context 1", + "description": "context 1 description", + "value_type": "binary", + }, + { + "name": "context 2", + "description": "context 2 description", + "value_type": "real-valued", + }, + ] + return payload_mab_normal + case "cmab_normal_binomial": + payload_mab_normal["exp_type"] = "cmab" + payload_mab_normal["reward_type"] = "binary" + payload_mab_normal["contexts"] = [ + { + "name": "context 1", + "description": "context 1 description", + "value_type": "binary", + }, + { + "name": "context 2", + "description": "context 2 description", + "value_type": "real-valued", + }, + ] + return payload_mab_normal + case "cmab_invalid_prior": + payload_mab_normal["exp_type"] = "cmab" + payload_mab_normal["prior_type"] = "beta" + payload_mab_normal["contexts"] = [ + { + "name": "context 1", + "description": "context 1 description", + "value_type": "binary", + }, + { + "name": "context 2", + "description": "context 2 description", + "value_type": "real-valued", + }, + ] + return payload_mab_normal + case "cmab_invalid_context": + payload_mab_normal["exp_type"] = "cmab" + return payload_mab_normal + + case _: + raise ValueError(f"Invalid input: {input}.") + + +class TestExperiment: + @fixture + def create_experiment_payload(self, request: FixtureRequest) -> dict: + """Fixture to create experiment payload based on request parameter.""" + return _get_experiment_payload(request.param) + + @mark.parametrize( + "create_experiment_payload, expected_response", + [ + ("base_beta_binom", 200), + ("base_normal", 200), + ("one_arm", 422), + ("no_notifications", 200), + ("invalid_prior", 422), + ("invalid_reward", 422), + ("invalid_alpha", 422), + ("invalid_beta", 422), + ("invalid_sigma", 422), + ("invalid_combo", 422), + ("incorrect_params", 422), + ("invalid_context_input", 422), + ("bayes_ab_normal_binom", 200), + ("bayes_ab_invalid_prior", 422), + 
("bayes_ab_invalid_arm", 422), + ("bayes_ab_invalid_context", 422), + ("cmab_normal", 200), + ("cmab_normal_binomial", 200), + ("cmab_invalid_prior", 422), + ("cmab_invalid_context", 422), + ], + indirect=["create_experiment_payload"], + ) + def test_create_experiment( + self, + create_experiment_payload: dict, + client: TestClient, + expected_response: int, + admin_token: str, + clean_experiments: None, + ) -> None: + response = client.post( + "/experiment", + json=create_experiment_payload, + headers={"Authorization": f"Bearer {admin_token}"}, + ) + + assert response.status_code == expected_response + + @fixture + def create_experiments( + self, + client: TestClient, + admin_token: str, + request: FixtureRequest, + create_experiment_payload: dict, + ) -> Generator: + experiments = [] + n_experiments = request.param if hasattr(request, "param") else 1 + for _ in range(n_experiments): + response = client.post( + "/experiment", + json=create_experiment_payload, + headers={"Authorization": f"Bearer {admin_token}"}, + ) + experiments.append(response.json()) + yield experiments + for experiment in experiments: + client.delete( + f"/experiment/id/{experiment['experiment_id']}", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + + @fixture + def create_mixed_experiments( + self, + client: TestClient, + admin_token: str, + request: FixtureRequest, + ) -> Generator: + mixed_payload = [] + for param in request.param: + payload = _get_experiment_payload(param) + response = client.post( + "/experiment", + json=payload, + headers={"Authorization": f"Bearer {admin_token}"}, + ) + mixed_payload.append(response.json()) + yield mixed_payload + for experiment in mixed_payload: + client.delete( + f"/experiment/id/{experiment['experiment_id']}", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + + @mark.parametrize( + "create_experiments, create_experiment_payload, n_expected", + [ + (0, "base_beta_binom", 0), + (2, "base_beta_binom", 2), + (5, "base_beta_binom", 
5), + ], + indirect=["create_experiments", "create_experiment_payload"], + ) + def test_get_all_experiments( + self, + client: TestClient, + admin_token: str, + n_expected: int, + create_experiments: list, + create_experiment_payload: dict, + ) -> None: + response = client.get( + "/experiment", headers={"Authorization": f"Bearer {admin_token}"} + ) + assert response.status_code == 200 + assert len(response.json()) == n_expected + + @mark.parametrize( + "create_mixed_experiments, exp_type, n_expected", + [ + ( + [ + "base_beta_binom", + "base_normal", + "bayes_ab_normal_binom", + "cmab_normal", + ], + "mab", + 2, + ), + ( + [ + "base_beta_binom", + "bayes_ab_normal_binom", + "bayes_ab_normal_binom", + "cmab_normal", + ], + "bayes_ab", + 2, + ), + ( + [ + "base_beta_binom", + "bayes_ab_normal_binom", + "cmab_normal", + "cmab_normal_binomial", + ], + "cmab", + 2, + ), + ], + indirect=["create_mixed_experiments"], + ) + def test_get_all_experiments_by_type( + self, + client: TestClient, + admin_token: str, + n_expected: int, + create_mixed_experiments: list, + exp_type: str, + ) -> None: + response = client.get( + f"/experiment/type/{exp_type}", + headers={"Authorization": f"Bearer {admin_token}"}, + ) + assert response.status_code == 200 + assert len(response.json()) == n_expected + + @mark.parametrize( + "create_experiments, create_experiment_payload, expected_response", + [(0, "base_beta_binom", 404), (2, "base_beta_binom", 200)], + indirect=["create_experiments", "create_experiment_payload"], + ) + def test_get_experiment( + self, + client: TestClient, + admin_token: str, + create_experiments: list, + create_experiment_payload: dict, + expected_response: int, + ) -> None: + id = create_experiments[0]["experiment_id"] if create_experiments else 999 + + response = client.get( + f"/experiment/id/{id}/", headers={"Authorization": f"Bearer {admin_token}"} + ) + assert response.status_code == expected_response + + @mark.parametrize("create_experiment_payload", 
["base_beta_binom"], indirect=True) + def test_draw_arm_draw_id_provided( + self, + client: TestClient, + create_experiments: list, + create_experiment_payload: dict, + workspace_api_key: str, + ) -> None: + id = create_experiments[0]["experiment_id"] + response = client.put( + f"/experiment/{id}/draw", + params={"draw_id": "test_draw"}, + headers={"Authorization": f"Bearer {workspace_api_key}"}, + ) + assert response.status_code == 200 + assert response.json()["draw_id"] == "test_draw" + + @mark.parametrize("create_experiment_payload", ["base_beta_binom"], indirect=True) + def test_draw_arm_no_draw_id_provided( + self, + client: TestClient, + create_experiments: list, + create_experiment_payload: dict, + workspace_api_key: str, + ) -> None: + id = create_experiments[0]["experiment_id"] + response = client.put( + f"/experiment/{id}/draw", + headers={"Authorization": f"Bearer {workspace_api_key}"}, + ) + assert response.status_code == 200 + assert len(response.json()["draw_id"]) == 36 + + @mark.parametrize( + "create_experiment_payload", + ["base_beta_binom", "bayes_ab_normal_binom", "cmab_normal"], + indirect=True, + ) + def test_one_outcome_per_draw( + self, + client: TestClient, + create_experiments: list, + create_experiment_payload: dict, + workspace_api_key: str, + ) -> None: + id = create_experiments[0]["experiment_id"] + exp_type = create_experiments[0]["exp_type"] + contexts = None + if exp_type == "cmab": + contexts = [ + {"context_id": context["context_id"], "context_value": 1} + for context in create_experiments[0]["contexts"] + ] + response = client.put( + f"/experiment/{id}/draw", + headers={"Authorization": f"Bearer {workspace_api_key}"}, + json=contexts, + ) + assert response.status_code == 200 + draw_id = response.json()["draw_id"] + + response = client.put( + f"/experiment/{id}/{draw_id}/1", + headers={"Authorization": f"Bearer {workspace_api_key}"}, + ) + + assert response.status_code == 200 + + response = client.put( + 
f"/experiment/{id}/{draw_id}/1", + headers={"Authorization": f"Bearer {workspace_api_key}"}, + ) + + assert response.status_code == 400 + + @mark.parametrize( + "n_draws, create_experiment_payload", + [ + (0, "base_beta_binom"), + (1, "base_beta_binom"), + (5, "base_beta_binom"), + (0, "bayes_ab_normal_binom"), + (1, "bayes_ab_normal_binom"), + (5, "bayes_ab_normal_binom"), + (0, "cmab_normal"), + (1, "cmab_normal"), + (5, "cmab_normal"), + ], + indirect=["create_experiment_payload"], + ) + def test_get_rewards( + self, + client: TestClient, + create_experiments: list, + n_draws: int, + create_experiment_payload: dict, + workspace_api_key: str, + ) -> None: + id = create_experiments[0]["experiment_id"] + exp_type = create_experiments[0]["exp_type"] + contexts = None + if exp_type == "cmab": + contexts = [ + {"context_id": context["context_id"], "context_value": 1} + for context in create_experiments[0]["contexts"] + ] + + for _ in range(n_draws): + response = client.put( + f"/experiment/{id}/draw", + headers={"Authorization": f"Bearer {workspace_api_key}"}, + json=contexts, + ) + assert response.status_code == 200 + draw_id = response.json()["draw_id"] + # put outcomes + response = client.put( + f"/experiment/{id}/{draw_id}/1", + headers={"Authorization": f"Bearer {workspace_api_key}"}, + ) + + response = client.get( + f"/experiment/{id}/rewards", + headers={"Authorization": f"Bearer {workspace_api_key}"}, + ) + + assert response.status_code == 200 + assert len(response.json()) == n_draws + + +class TestNotifications: + @fixture() + def create_experiment_payload(self, request: FixtureRequest) -> dict: + payload: dict = copy.deepcopy(mab_beta_binom_payload) + payload["arms"] = list(payload["arms"]) + + match request.param: + case "base": + pass + case "daysElapsed_only": + payload["notifications"]["onTrialCompletion"] = False + payload["notifications"]["onDaysElapsed"] = True + case "trialCompletion_only": + payload["notifications"]["onTrialCompletion"] = True + 
case "percentBetter_only": + payload["notifications"]["onTrialCompletion"] = False + payload["notifications"]["onPercentBetter"] = True + case "all_notifications": + payload["notifications"]["onDaysElapsed"] = True + payload["notifications"]["onPercentBetter"] = True + case "no_notifications": + payload["notifications"]["onTrialCompletion"] = False + case "daysElapsed_missing": + payload["notifications"]["daysElapsed"] = 0 + payload["notifications"]["onDaysElapsed"] = True + case "trialCompletion_missing": + payload["notifications"]["numberOfTrials"] = 0 + payload["notifications"]["onTrialCompletion"] = True + case "percentBetter_missing": + payload["notifications"]["percentBetterThreshold"] = 0 + payload["notifications"]["onPercentBetter"] = True + case _: + raise ValueError("Invalid parameter") + + return payload + + @mark.parametrize( + "create_experiment_payload, expected_response", + [ + ("base", 200), + ("daysElapsed_only", 200), + ("trialCompletion_only", 200), + ("percentBetter_only", 200), + ("all_notifications", 200), + ("no_notifications", 200), + ("daysElapsed_missing", 422), + ("trialCompletion_missing", 422), + ("percentBetter_missing", 422), + ], + indirect=["create_experiment_payload"], + ) + def test_notifications( + self, + client: TestClient, + admin_token: str, + create_experiment_payload: dict, + expected_response: int, + clean_experiments: None, + ) -> None: + response = client.post( + "/experiment", + json=create_experiment_payload, + headers={"Authorization": f"Bearer {admin_token}"}, + ) + + assert response.status_code == expected_response diff --git a/backend/tests/test_mabs.py b/backend/tests/test_mabs.py deleted file mode 100644 index e65ccb6..0000000 --- a/backend/tests/test_mabs.py +++ /dev/null @@ -1,463 +0,0 @@ -import copy -import os -from typing import Generator - -import numpy as np -from fastapi.testclient import TestClient -from pytest import FixtureRequest, fixture, mark -from sqlalchemy.orm import Session - -from 
backend.app.mab.models import MABArmDB, MultiArmedBanditDB -from backend.app.models import NotificationsDB - -base_beta_binom_payload = { - "name": "Test", - "description": "Test description", - "prior_type": "beta", - "reward_type": "binary", - "sticky_assignment": False, - "arms": [ - { - "name": "arm 1", - "description": "arm 1 description", - "alpha_init": 5, - "beta_init": 1, - }, - { - "name": "arm 2", - "description": "arm 2 description", - "alpha_init": 1, - "beta_init": 4, - }, - ], - "notifications": { - "onTrialCompletion": True, - "numberOfTrials": 2, - "onDaysElapsed": False, - "daysElapsed": 3, - "onPercentBetter": False, - "percentBetterThreshold": 5, - }, -} - -base_normal_payload = base_beta_binom_payload.copy() -base_normal_payload["prior_type"] = "normal" -base_normal_payload["reward_type"] = "real-valued" -base_normal_payload["arms"] = [ - { - "name": "arm 1", - "description": "arm 1 description", - "mu_init": 2, - "sigma_init": 3, - }, - { - "name": "arm 2", - "description": "arm 2 description", - "mu_init": 3, - "sigma_init": 7, - }, -] - - -@fixture -def admin_token(client: TestClient) -> str: - response = client.post( - "/login", - data={ - "username": os.environ.get("ADMIN_USERNAME", ""), - "password": os.environ.get("ADMIN_PASSWORD", ""), - }, - ) - token = response.json()["access_token"] - return token - - -@fixture -def clean_mabs(db_session: Session) -> Generator: - yield - db_session.query(NotificationsDB).delete() - db_session.query(MABArmDB).delete() - db_session.query(MultiArmedBanditDB).delete() - db_session.commit() - - -class TestMab: - @fixture - def create_mab_payload(self, request: FixtureRequest) -> dict: - payload_beta_binom: dict = copy.deepcopy(base_beta_binom_payload) - payload_beta_binom["arms"] = list(payload_beta_binom["arms"]) - - payload_normal: dict = copy.deepcopy(base_normal_payload) - payload_normal["arms"] = list(payload_normal["arms"]) - - if request.param == "base_beta_binom": - return payload_beta_binom - if 
request.param == "base_normal": - return payload_normal - if request.param == "one_arm": - payload_beta_binom["arms"].pop() - return payload_beta_binom - if request.param == "no_notifications": - payload_beta_binom["notifications"]["onTrialCompletion"] = False - return payload_beta_binom - if request.param == "invalid_prior": - payload_beta_binom["prior_type"] = "invalid" - return payload_beta_binom - if request.param == "invalid_reward": - payload_beta_binom["reward_type"] = "invalid" - return payload_beta_binom - if request.param == "invalid_alpha": - payload_beta_binom["arms"][0]["alpha_init"] = -1 - return payload_beta_binom - if request.param == "invalid_beta": - payload_beta_binom["arms"][0]["beta_init"] = -1 - return payload_beta_binom - if request.param == "invalid_combo_1": - payload_beta_binom["prior_type"] = "normal" - return payload_beta_binom - if request.param == "invalid_combo_2": - payload_beta_binom["reward_type"] = "continuous" - return payload_beta_binom - if request.param == "incorrect_params": - payload_beta_binom["arms"][0].pop("alpha_init") - return payload_beta_binom - if request.param == "invalid_sigma": - payload_normal["arms"][0]["sigma_init"] = 0.0 - return payload_normal - if request.param == "with_sticky_assignment": - payload_beta_binom["sticky_assignment"] = True - return payload_beta_binom - else: - raise ValueError("Invalid parameter") - - @mark.parametrize( - "create_mab_payload, expected_response", - [ - ("base_beta_binom", 200), - ("base_normal", 200), - ("one_arm", 422), - ("no_notifications", 200), - ("invalid_prior", 422), - ("invalid_reward", 422), - ("invalid_alpha", 422), - ("invalid_beta", 422), - ("invalid_combo_1", 422), - ("invalid_combo_2", 422), - ("incorrect_params", 422), - ("invalid_sigma", 422), - ], - indirect=["create_mab_payload"], - ) - def test_create_mab( - self, - create_mab_payload: dict, - client: TestClient, - expected_response: int, - admin_token: str, - clean_mabs: None, - ) -> None: - response = 
client.post( - "/mab", - json=create_mab_payload, - headers={"Authorization": f"Bearer {admin_token}"}, - ) - - assert response.status_code == expected_response - - @fixture - def create_mabs( - self, - client: TestClient, - admin_token: str, - request: FixtureRequest, - create_mab_payload: dict, - ) -> Generator: - mabs = [] - n_mabs = request.param if hasattr(request, "param") else 1 - for _ in range(n_mabs): - response = client.post( - "/mab", - json=create_mab_payload, - headers={"Authorization": f"Bearer {admin_token}"}, - ) - mabs.append(response.json()) - yield mabs - for mab in mabs: - client.delete( - f"/mab/{mab['experiment_id']}", - headers={"Authorization": f"Bearer {admin_token}"}, - ) - - @mark.parametrize( - "create_mabs, create_mab_payload, n_expected", - [ - (0, "base_beta_binom", 0), - (2, "base_beta_binom", 2), - (5, "base_beta_binom", 5), - ], - indirect=["create_mabs", "create_mab_payload"], - ) - def test_get_all_mabs( - self, - client: TestClient, - admin_token: str, - n_expected: int, - create_mabs: list, - create_mab_payload: dict, - ) -> None: - response = client.get( - "/mab", headers={"Authorization": f"Bearer {admin_token}"} - ) - assert response.status_code == 200 - assert len(response.json()) == n_expected - - @mark.parametrize( - "create_mabs, create_mab_payload, expected_response", - [(0, "base_beta_binom", 404), (2, "base_beta_binom", 200)], - indirect=["create_mabs", "create_mab_payload"], - ) - def test_get_mab( - self, - client: TestClient, - admin_token: str, - create_mabs: list, - create_mab_payload: dict, - expected_response: int, - ) -> None: - id = create_mabs[0]["experiment_id"] if create_mabs else 999 - - response = client.get( - f"/mab/{id}/", headers={"Authorization": f"Bearer {admin_token}"} - ) - assert response.status_code == expected_response - - @mark.parametrize("create_mab_payload", ["base_beta_binom"], indirect=True) - def test_draw_arm_draw_id_provided( - self, - client: TestClient, - create_mabs: list, - 
create_mab_payload: dict, - workspace_api_key: str, - ) -> None: - id = create_mabs[0]["experiment_id"] - response = client.get( - f"/mab/{id}/draw", - params={"draw_id": "test_draw"}, - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - assert response.status_code == 200 - assert response.json()["draw_id"] == "test_draw" - - @mark.parametrize("create_mab_payload", ["base_beta_binom"], indirect=True) - def test_draw_arm_no_draw_id_provided( - self, - client: TestClient, - create_mabs: list, - create_mab_payload: dict, - workspace_api_key: str, - ) -> None: - id = create_mabs[0]["experiment_id"] - response = client.get( - f"/mab/{id}/draw", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - assert response.status_code == 200 - assert len(response.json()["draw_id"]) == 36 - - @mark.parametrize( - "create_mab_payload, client_id, expected_response", - [ - ("with_sticky_assignment", None, 400), - ("with_sticky_assignment", "test_client_id", 200), - ], - indirect=["create_mab_payload"], - ) - def test_draw_arm_sticky_assignment_with_client_id( - self, - client: TestClient, - admin_token: str, - create_mab_payload: dict, - create_mabs: list, - client_id: str | None, - expected_response: int, - workspace_api_key: str, - ) -> None: - mabs = create_mabs - id = mabs[0]["experiment_id"] - response = client.get( - f"/mab/{id}/draw{'?client_id=' + client_id if client_id else ''}", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - assert response.status_code == expected_response - - @mark.parametrize("create_mab_payload", ["with_sticky_assignment"], indirect=True) - def test_draw_arm_sticky_assignment_client_id_provided( - self, - client: TestClient, - admin_token: str, - create_mab_payload: dict, - create_mabs: list, - workspace_api_key: str, - ) -> None: - mabs = create_mabs - id = mabs[0]["experiment_id"] - response = client.get( - f"/mab/{id}/draw?client_id=123", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - 
assert response.status_code == 200 - - @mark.parametrize("create_mab_payload", ["with_sticky_assignment"], indirect=True) - def test_draw_arm_sticky_assignment_similar_arms( - self, - client: TestClient, - admin_token: str, - create_mab_payload: dict, - create_mabs: list, - workspace_api_key: str, - ) -> None: - mabs = create_mabs - id = mabs[0]["experiment_id"] - - arm_ids = [] - for _ in range(10): - response = client.get( - f"/mab/{id}/draw?client_id=123", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - arm_ids.append(response.json()["arm"]["arm_id"]) - assert np.unique(arm_ids).size == 1 - - @mark.parametrize("create_mab_payload", ["base_beta_binom"], indirect=True) - def test_one_outcome_per_draw( - self, - client: TestClient, - create_mabs: list, - create_mab_payload: dict, - workspace_api_key: str, - ) -> None: - id = create_mabs[0]["experiment_id"] - response = client.get( - f"/mab/{id}/draw", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - assert response.status_code == 200 - draw_id = response.json()["draw_id"] - - response = client.put( - f"/mab/{id}/{draw_id}/1", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - - assert response.status_code == 200 - - response = client.put( - f"/mab/{id}/{draw_id}/1", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - - assert response.status_code == 400 - - @mark.parametrize( - "n_draws, create_mab_payload", - [(0, "base_beta_binom"), (1, "base_beta_binom"), (5, "base_beta_binom")], - indirect=["create_mab_payload"], - ) - def test_get_outcomes( - self, - client: TestClient, - create_mabs: list, - n_draws: int, - create_mab_payload: dict, - workspace_api_key: str, - ) -> None: - id = create_mabs[0]["experiment_id"] - - for _ in range(n_draws): - response = client.get( - f"/mab/{id}/draw", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - assert response.status_code == 200 - draw_id = response.json()["draw_id"] - # put outcomes - 
response = client.put( - f"/mab/{id}/{draw_id}/1", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - - response = client.get( - f"/mab/{id}/outcomes", - headers={"Authorization": f"Bearer {workspace_api_key}"}, - ) - - assert response.status_code == 200 - assert len(response.json()) == n_draws - - -class TestNotifications: - @fixture() - def create_mab_payload(self, request: FixtureRequest) -> dict: - payload: dict = copy.deepcopy(base_beta_binom_payload) - payload["arms"] = list(payload["arms"]) - - match request.param: - case "base": - pass - case "daysElapsed_only": - payload["notifications"]["onTrialCompletion"] = False - payload["notifications"]["onDaysElapsed"] = True - case "trialCompletion_only": - payload["notifications"]["onTrialCompletion"] = True - case "percentBetter_only": - payload["notifications"]["onTrialCompletion"] = False - payload["notifications"]["onPercentBetter"] = True - case "all_notifications": - payload["notifications"]["onDaysElapsed"] = True - payload["notifications"]["onPercentBetter"] = True - case "no_notifications": - payload["notifications"]["onTrialCompletion"] = False - case "daysElapsed_missing": - payload["notifications"]["daysElapsed"] = 0 - payload["notifications"]["onDaysElapsed"] = True - case "trialCompletion_missing": - payload["notifications"]["numberOfTrials"] = 0 - payload["notifications"]["onTrialCompletion"] = True - case "percentBetter_missing": - payload["notifications"]["percentBetterThreshold"] = 0 - payload["notifications"]["onPercentBetter"] = True - case _: - raise ValueError("Invalid parameter") - - return payload - - @mark.parametrize( - "create_mab_payload, expected_response", - [ - ("base", 200), - ("daysElapsed_only", 200), - ("trialCompletion_only", 200), - ("percentBetter_only", 200), - ("all_notifications", 200), - ("no_notifications", 200), - ("daysElapsed_missing", 422), - ("trialCompletion_missing", 422), - ("percentBetter_missing", 422), - ], - indirect=["create_mab_payload"], - ) 
- def test_notifications( - self, - client: TestClient, - admin_token: str, - create_mab_payload: dict, - expected_response: int, - clean_mabs: None, - ) -> None: - response = client.post( - "/mab", - json=create_mab_payload, - headers={"Authorization": f"Bearer {admin_token}"}, - ) - - assert response.status_code == expected_response diff --git a/backend/tests/test_messages.py b/backend/tests/test_messages.py index 86396b4..253e4f2 100644 --- a/backend/tests/test_messages.py +++ b/backend/tests/test_messages.py @@ -9,19 +9,20 @@ base_mab_payload = { "name": "Test", - "description": "Test description", + "description": "Test description.", + "exp_type": "mab", "prior_type": "beta", "reward_type": "binary", "arms": [ { "name": "arm 1", - "description": "arm 1 description", + "description": "arm 1 description.", "alpha_init": 5, "beta_init": 1, }, { "name": "arm 2", - "description": "arm 2 description", + "description": "arm 2 description.", "alpha_init": 1, "beta_init": 4, }, @@ -34,6 +35,8 @@ "onPercentBetter": False, "percentBetterThreshold": 5, }, + "contexts": [], + "clients": [], } @@ -53,13 +56,13 @@ def admin_token(client: TestClient) -> str: @fixture def experiment_id(client: TestClient, admin_token: str) -> Generator[int, None, None]: response = client.post( - "/mab", + "/experiment", headers={"Authorization": f"Bearer {admin_token}"}, json=base_mab_payload, ) yield response.json()["experiment_id"] client.delete( - f"/mab/{response.json()['experiment_id']}", + f"/experiment/id/{response.json()['experiment_id']}", headers={"Authorization": f"Bearer {admin_token}"}, ) diff --git a/backend/tests/test_notifications_job.py b/backend/tests/test_notifications_job.py index 8b1ef80..911cbfc 100644 --- a/backend/tests/test_notifications_job.py +++ b/backend/tests/test_notifications_job.py @@ -12,21 +12,22 @@ from backend.jobs import create_notifications from backend.jobs.create_notifications import process_notifications -base_mab_payload = { +base_experiment_payload 
= { "name": "Test", - "description": "Test description", + "description": "Test description.", + "exp_type": "mab", "prior_type": "beta", "reward_type": "binary", "arms": [ { "name": "arm 1", - "description": "arm 1 description", + "description": "arm 1 description.", "alpha_init": 5, "beta_init": 1, }, { "name": "arm 2", - "description": "arm 2 description", + "description": "arm 2 description.", "alpha_init": 1, "beta_init": 4, }, @@ -39,6 +40,8 @@ "onPercentBetter": False, "percentBetterThreshold": 5, }, + "contexts": [], + "clients": [], } @@ -67,65 +70,65 @@ def admin_token(client: TestClient) -> str: class TestNotificationsJob: @fixture - def create_mabs_days_elapsed( + def create_experiments_days_elapsed( self, client: TestClient, admin_token: str, request: FixtureRequest ) -> Generator: - mabs = [] - n_mabs, days_elapsed = request.param + experiments = [] + n_experiments, days_elapsed = request.param - payload: dict = copy.deepcopy(base_mab_payload) + payload: dict = copy.deepcopy(base_experiment_payload) payload["notifications"]["onDaysElapsed"] = True payload["notifications"]["daysElapsed"] = days_elapsed - for _ in range(n_mabs): + for _ in range(n_experiments): response = client.post( - "/mab", + "/experiment", json=payload, headers={"Authorization": f"Bearer {admin_token}"}, ) - mabs.append(response.json()) - yield mabs - for mab in mabs: + experiments.append(response.json()) + yield experiments + for experiment in experiments: client.delete( - f"/mab/{mab['experiment_id']}", + f"/experiment/id/{experiment['experiment_id']}", headers={"Authorization": f"Bearer {admin_token}"}, ) @fixture - def create_mabs_trials_run( + def create_experiments_trials_run( self, client: TestClient, admin_token: str, request: FixtureRequest ) -> Generator: - mabs = [] - n_mabs, n_trials = request.param + experiments = [] + n_experiments, n_trials = request.param - payload: dict = copy.deepcopy(base_mab_payload) + payload: dict = copy.deepcopy(base_experiment_payload) 
payload["notifications"]["onTrialCompletion"] = True payload["notifications"]["numberOfTrials"] = n_trials - for _ in range(n_mabs): + for _ in range(n_experiments): response = client.post( - "/mab", + "/experiment", json=payload, headers={"Authorization": f"Bearer {admin_token}"}, ) - mabs.append(response.json()) - yield mabs - for mab in mabs: + experiments.append(response.json()) + yield experiments + for experiment in experiments: client.delete( - f"/mab/{mab['experiment_id']}", + f"/experiment/id/{experiment['experiment_id']}", headers={"Authorization": f"Bearer {admin_token}"}, ) @mark.parametrize( - "create_mabs_days_elapsed, days_elapsed", + "create_experiments_days_elapsed, days_elapsed", [((3, 4), 4), ((4, 62), 64), ((3, 40), 40)], - indirect=["create_mabs_days_elapsed"], + indirect=["create_experiments_days_elapsed"], ) async def test_days_elapsed_notification( self, client: TestClient, admin_token: str, - create_mabs_days_elapsed: list[dict], + create_experiments_days_elapsed: list[dict], db_session: Session, days_elapsed: int, monkeypatch: MonkeyPatch, @@ -137,18 +140,18 @@ async def test_days_elapsed_notification( fake_datetime(days_elapsed), ) n_processed = await process_notifications(asession) - assert n_processed == len(create_mabs_days_elapsed) + assert n_processed == len(create_experiments_days_elapsed) @mark.parametrize( - "create_mabs_days_elapsed, days_elapsed", + "create_experiments_days_elapsed, days_elapsed", [((3, 4), 3), ((4, 62), 50), ((3, 40), 0)], - indirect=["create_mabs_days_elapsed"], + indirect=["create_experiments_days_elapsed"], ) async def test_days_elapsed_notification_not_sent( self, client: TestClient, admin_token: str, - create_mabs_days_elapsed: list[dict], + create_experiments_days_elapsed: list[dict], db_session: Session, days_elapsed: int, monkeypatch: MonkeyPatch, @@ -163,16 +166,16 @@ async def test_days_elapsed_notification_not_sent( assert n_processed == 0 @mark.parametrize( - "create_mabs_trials_run, n_trials", + 
"create_experiments_trials_run, n_trials", [((3, 4), 4), ((4, 62), 64), ((3, 40), 40)], - indirect=["create_mabs_trials_run"], + indirect=["create_experiments_trials_run"], ) async def test_trials_run_notification( self, client: TestClient, admin_token: str, n_trials: int, - create_mabs_trials_run: list[dict], + create_experiments_trials_run: list[dict], db_session: Session, asession: AsyncSession, workspace_api_key: str, @@ -180,11 +183,11 @@ async def test_trials_run_notification( n_processed = await process_notifications(asession) assert n_processed == 0 headers = {"Authorization": f"Bearer {workspace_api_key}"} - for mab in create_mabs_trials_run: + for experiment in create_experiments_trials_run: for i in range(n_trials): - draw_id = f"draw_{i}_{mab['experiment_id']}" - response = client.get( - f"/mab/{mab['experiment_id']}/draw", + draw_id = f"draw_{i}_{experiment['experiment_id']}" + response = client.put( + f"/experiment/{experiment['experiment_id']}/draw", params={"draw_id": draw_id}, headers=headers, ) @@ -192,10 +195,10 @@ async def test_trials_run_notification( assert response.json()["draw_id"] == draw_id response = client.put( - f"/mab/{mab['experiment_id']}/{draw_id}/1", + f"/experiment/{experiment['experiment_id']}/{draw_id}/1", headers=headers, ) assert response.status_code == 200 n_processed = await process_notifications(asession) await asyncio.sleep(0.1) - assert n_processed == len(create_mabs_trials_run) + assert n_processed == len(create_experiments_trials_run) diff --git a/frontend/package-lock.json b/frontend/package-lock.json index bd3f82d..d4c756a 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -83,14 +83,25 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/@babel/runtime": { - "version": "7.27.0", - "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.27.0.tgz", - "integrity": "sha512-VtPOkrdPHZsKc/clNqyi9WUA8TINkZ4cGk63UUE3u4pmB2k+ZMQRDuIOagv8UVd6j7k0T3+RRIb7beKTebNbcw==", - 
"license": "MIT", + "node_modules/@ampproject/remapping": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/@ampproject/remapping/-/remapping-2.3.0.tgz", + "integrity": "sha512-30iZtAPgz+LTIYoeivqYo853f02jBYSd5uGnGpkFV0M3xOt9aN73erkgYAmZU43x4VfqcnLxW9Kpg3R5LC4YYw==", + "dev": true, + "license": "Apache-2.0", "dependencies": { - "regenerator-runtime": "^0.14.0" + "@jridgewell/gen-mapping": "^0.3.5", + "@jridgewell/trace-mapping": "^0.3.24" }, + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/@babel/runtime": { + "version": "7.27.6", + "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.27.6.tgz", + "integrity": "sha512-vbavdySgbTTrmFE+EsiqUTzlOr5bzlnJtUv9PynGCAKvfQqjIXbvFdumPM/GxMDfyuGMJaJAU6TO4zc1Jf1i8Q==", + "license": "MIT", "engines": { "node": ">=6.9.0" } @@ -130,9 +141,9 @@ } }, "node_modules/@eslint-community/eslint-utils": { - "version": "4.6.1", - "resolved": "https://registry.npmjs.org/@eslint-community/eslint-utils/-/eslint-utils-4.6.1.tgz", - "integrity": "sha512-KTsJMmobmbrFLe3LDh0PC2FXpcSYJt/MLjlkh/9LEnmKYLSYmT/0EW9JWANjeoemiuZrmogti0tW5Ch+qNUYDw==", + "version": "4.7.0", + "resolved": "https://registry.npmjs.org/@eslint-community/eslint-utils/-/eslint-utils-4.7.0.tgz", + "integrity": "sha512-dyybb3AcajC7uha6CvhdVRJqaKyn7w2YKqKyAN37NKYgZT36w+iRb0Dymmc5qEJ549c/S31cMMSFd75bteCpCw==", "dev": true, "license": "MIT", "dependencies": { @@ -193,28 +204,28 @@ } }, "node_modules/@floating-ui/core": { - "version": "1.6.9", - "resolved": "https://registry.npmjs.org/@floating-ui/core/-/core-1.6.9.tgz", - "integrity": "sha512-uMXCuQ3BItDUbAMhIXw7UPXRfAlOAvZzdK9BWpE60MCn+Svt3aLn9jsPTi/WNGlRUu2uI0v5S7JiIUsbsvh3fw==", + "version": "1.7.1", + "resolved": "https://registry.npmjs.org/@floating-ui/core/-/core-1.7.1.tgz", + "integrity": "sha512-azI0DrjMMfIug/ExbBaeDVJXcY0a7EPvPjb2xAJPa4HeimBX+Z18HK8QQR3jb6356SnDDdxx+hinMLcJEDdOjw==", "license": "MIT", "dependencies": { "@floating-ui/utils": "^0.2.9" } }, 
"node_modules/@floating-ui/dom": { - "version": "1.6.13", - "resolved": "https://registry.npmjs.org/@floating-ui/dom/-/dom-1.6.13.tgz", - "integrity": "sha512-umqzocjDgNRGTuO7Q8CU32dkHkECqI8ZdMZ5Swb6QAM0t5rnlrN3lGo1hdpscRd3WS8T6DKYK4ephgIH9iRh3w==", + "version": "1.7.1", + "resolved": "https://registry.npmjs.org/@floating-ui/dom/-/dom-1.7.1.tgz", + "integrity": "sha512-cwsmW/zyw5ltYTUeeYJ60CnQuPqmGwuGVhG9w0PRaRKkAyi38BT5CKrpIbb+jtahSwUl04cWzSx9ZOIxeS6RsQ==", "license": "MIT", "dependencies": { - "@floating-ui/core": "^1.6.0", + "@floating-ui/core": "^1.7.1", "@floating-ui/utils": "^0.2.9" } }, "node_modules/@floating-ui/react-dom": { - "version": "2.1.2", - "resolved": "https://registry.npmjs.org/@floating-ui/react-dom/-/react-dom-2.1.2.tgz", - "integrity": "sha512-06okr5cgPzMNBy+Ycse2A6udMi4bqwW/zgBF/rwjcNqWkyr82Mcg8b0vjX8OJpZFy/FKjJmw6wV7t44kK6kW7A==", + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/@floating-ui/react-dom/-/react-dom-2.1.3.tgz", + "integrity": "sha512-huMBfiU9UnQ2oBwIhgzyIiSpVgvlDstU8CX0AF+wS+KzmYMs0J2a3GwuFHV1Lz+jlrQGeC1fF+Nv0QoumyV0bA==", "license": "MIT", "dependencies": { "@floating-ui/dom": "^1.0.0" @@ -324,23 +335,89 @@ "url": "https://github.com/chalk/strip-ansi?sponsor=1" } }, + "node_modules/@isaacs/fs-minipass": { + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/@isaacs/fs-minipass/-/fs-minipass-4.0.1.tgz", + "integrity": "sha512-wgm9Ehl2jpeqP3zw/7mo3kRHFp5MEDhqAdwy1fTGkHAwnkGOVsgpvQhL8B5n1qlb01jV3n/bI0ZfZp5lWA1k4w==", + "dev": true, + "license": "ISC", + "dependencies": { + "minipass": "^7.0.4" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@jridgewell/gen-mapping": { + "version": "0.3.8", + "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.8.tgz", + "integrity": "sha512-imAbBGkb+ebQyxKgzv5Hu2nmROxoDOXHh80evxdoXNOrvAnVx7zimzc1Oo5h9RlfV4vPXaE2iM5pOFbvOCClWA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@jridgewell/set-array": "^1.2.1", + 
"@jridgewell/sourcemap-codec": "^1.4.10", + "@jridgewell/trace-mapping": "^0.3.24" + }, + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/@jridgewell/resolve-uri": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/@jridgewell/resolve-uri/-/resolve-uri-3.1.2.tgz", + "integrity": "sha512-bRISgCIjP20/tbWSPWMEi54QVPRZExkuD9lJL+UIxUKtwVJA8wW1Trb1jMs1RFXo1CBTNZ/5hpC9QvmKWdopKw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/@jridgewell/set-array": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/@jridgewell/set-array/-/set-array-1.2.1.tgz", + "integrity": "sha512-R8gLRTZeyp03ymzP/6Lil/28tGeGEzhx1q2k703KGWRAI1VdvPIXdG70VJc2pAMw3NA6JKL5hhFu1sJX0Mnn/A==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/@jridgewell/sourcemap-codec": { + "version": "1.5.0", + "resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.5.0.tgz", + "integrity": "sha512-gv3ZRaISU3fjPAgNsriBRqGWQL6quFx04YMPW/zD8XMLsU32mhCCbfbO6KZFLjvYpCZ8zyDEgqsgf+PwPaM7GQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/@jridgewell/trace-mapping": { + "version": "0.3.25", + "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.25.tgz", + "integrity": "sha512-vNk6aEwybGtawWmy/PzwnGDOjCkLWSD2wqvjGGAgOAwCGWySYXfYoxt00IJkTF+8Lb57DwOb3Aa0o9CApepiYQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@jridgewell/resolve-uri": "^3.1.0", + "@jridgewell/sourcemap-codec": "^1.4.14" + } + }, "node_modules/@napi-rs/wasm-runtime": { - "version": "0.2.9", - "resolved": "https://registry.npmjs.org/@napi-rs/wasm-runtime/-/wasm-runtime-0.2.9.tgz", - "integrity": "sha512-OKRBiajrrxB9ATokgEQoG87Z25c67pCpYcCwmXYX8PBftC9pBfN18gnm/fh1wurSLEKIAt+QRFLFCQISrb66Jg==", + "version": "0.2.10", + "resolved": "https://registry.npmjs.org/@napi-rs/wasm-runtime/-/wasm-runtime-0.2.10.tgz", + "integrity": 
"sha512-bCsCyeZEwVErsGmyPNSzwfwFn4OdxBj0mmv6hOFucB/k81Ojdu68RbZdxYsRQUPc9l6SU5F/cG+bXgWs3oUgsQ==", "dev": true, "license": "MIT", "optional": true, "dependencies": { - "@emnapi/core": "^1.4.0", - "@emnapi/runtime": "^1.4.0", + "@emnapi/core": "^1.4.3", + "@emnapi/runtime": "^1.4.3", "@tybys/wasm-util": "^0.9.0" } }, "node_modules/@next/env": { - "version": "14.2.28", - "resolved": "https://registry.npmjs.org/@next/env/-/env-14.2.28.tgz", - "integrity": "sha512-PAmWhJfJQlP+kxZwCjrVd9QnR5x0R3u0mTXTiZDgSd4h5LdXmjxCCWbN9kq6hkZBOax8Rm3xDW5HagWyJuT37g==", + "version": "14.2.29", + "resolved": "https://registry.npmjs.org/@next/env/-/env-14.2.29.tgz", + "integrity": "sha512-UzgLR2eBfhKIQt0aJ7PWH7XRPYw7SXz0Fpzdl5THjUnvxy4kfBk9OU4RNPNiETewEEtaBcExNFNn1QWH8wQTjg==", "license": "MIT" }, "node_modules/@next/eslint-plugin-next": { @@ -354,9 +431,9 @@ } }, "node_modules/@next/swc-darwin-arm64": { - "version": "14.2.28", - "resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-14.2.28.tgz", - "integrity": "sha512-kzGChl9setxYWpk3H6fTZXXPFFjg7urptLq5o5ZgYezCrqlemKttwMT5iFyx/p1e/JeglTwDFRtb923gTJ3R1w==", + "version": "14.2.29", + "resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-14.2.29.tgz", + "integrity": "sha512-wWtrAaxCVMejxPHFb1SK/PVV1WDIrXGs9ki0C/kUM8ubKHQm+3hU9MouUywCw8Wbhj3pewfHT2wjunLEr/TaLA==", "cpu": [ "arm64" ], @@ -370,9 +447,9 @@ } }, "node_modules/@next/swc-darwin-x64": { - "version": "14.2.28", - "resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-14.2.28.tgz", - "integrity": "sha512-z6FXYHDJlFOzVEOiiJ/4NG8aLCeayZdcRSMjPDysW297Up6r22xw6Ea9AOwQqbNsth8JNgIK8EkWz2IDwaLQcw==", + "version": "14.2.29", + "resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-14.2.29.tgz", + "integrity": "sha512-7Z/jk+6EVBj4pNLw/JQrvZVrAh9Bv8q81zCFSfvTMZ51WySyEHWVpwCEaJY910LyBftv2F37kuDPQm0w9CEXyg==", "cpu": [ "x64" ], @@ -386,9 +463,9 @@ } }, "node_modules/@next/swc-linux-arm64-gnu": 
{ - "version": "14.2.28", - "resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-14.2.28.tgz", - "integrity": "sha512-9ARHLEQXhAilNJ7rgQX8xs9aH3yJSj888ssSjJLeldiZKR4D7N08MfMqljk77fAwZsWwsrp8ohHsMvurvv9liQ==", + "version": "14.2.29", + "resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-14.2.29.tgz", + "integrity": "sha512-o6hrz5xRBwi+G7JFTHc+RUsXo2lVXEfwh4/qsuWBMQq6aut+0w98WEnoNwAwt7hkEqegzvazf81dNiwo7KjITw==", "cpu": [ "arm64" ], @@ -402,9 +479,9 @@ } }, "node_modules/@next/swc-linux-arm64-musl": { - "version": "14.2.28", - "resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-14.2.28.tgz", - "integrity": "sha512-p6gvatI1nX41KCizEe6JkF0FS/cEEF0u23vKDpl+WhPe/fCTBeGkEBh7iW2cUM0rvquPVwPWdiUR6Ebr/kQWxQ==", + "version": "14.2.29", + "resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-14.2.29.tgz", + "integrity": "sha512-9i+JEHBOVgqxQ92HHRFlSW1EQXqa/89IVjtHgOqsShCcB/ZBjTtkWGi+SGCJaYyWkr/lzu51NTMCfKuBf7ULNw==", "cpu": [ "arm64" ], @@ -418,9 +495,9 @@ } }, "node_modules/@next/swc-linux-x64-gnu": { - "version": "14.2.28", - "resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-14.2.28.tgz", - "integrity": "sha512-nsiSnz2wO6GwMAX2o0iucONlVL7dNgKUqt/mDTATGO2NY59EO/ZKnKEr80BJFhuA5UC1KZOMblJHWZoqIJddpA==", + "version": "14.2.29", + "resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-14.2.29.tgz", + "integrity": "sha512-B7JtMbkUwHijrGBOhgSQu2ncbCYq9E7PZ7MX58kxheiEOwdkM+jGx0cBb+rN5AeqF96JypEppK6i/bEL9T13lA==", "cpu": [ "x64" ], @@ -434,9 +511,9 @@ } }, "node_modules/@next/swc-linux-x64-musl": { - "version": "14.2.28", - "resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-14.2.28.tgz", - "integrity": "sha512-+IuGQKoI3abrXFqx7GtlvNOpeExUH1mTIqCrh1LGFf8DnlUcTmOOCApEnPJUSLrSbzOdsF2ho2KhnQoO0I1RDw==", + "version": "14.2.29", + "resolved": 
"https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-14.2.29.tgz", + "integrity": "sha512-yCcZo1OrO3aQ38B5zctqKU1Z3klOohIxug6qdiKO3Q3qNye/1n6XIs01YJ+Uf+TdpZQ0fNrOQI2HrTLF3Zprnw==", "cpu": [ "x64" ], @@ -450,9 +527,9 @@ } }, "node_modules/@next/swc-win32-arm64-msvc": { - "version": "14.2.28", - "resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-14.2.28.tgz", - "integrity": "sha512-l61WZ3nevt4BAnGksUVFKy2uJP5DPz2E0Ma/Oklvo3sGj9sw3q7vBWONFRgz+ICiHpW5mV+mBrkB3XEubMrKaA==", + "version": "14.2.29", + "resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-14.2.29.tgz", + "integrity": "sha512-WnrfeOEtTVidI9Z6jDLy+gxrpDcEJtZva54LYC0bSKQqmyuHzl0ego+v0F/v2aXq0am67BRqo/ybmmt45Tzo4A==", "cpu": [ "arm64" ], @@ -466,9 +543,9 @@ } }, "node_modules/@next/swc-win32-ia32-msvc": { - "version": "14.2.28", - "resolved": "https://registry.npmjs.org/@next/swc-win32-ia32-msvc/-/swc-win32-ia32-msvc-14.2.28.tgz", - "integrity": "sha512-+Kcp1T3jHZnJ9v9VTJ/yf1t/xmtFAc/Sge4v7mVc1z+NYfYzisi8kJ9AsY8itbgq+WgEwMtOpiLLJsUy2qnXZw==", + "version": "14.2.29", + "resolved": "https://registry.npmjs.org/@next/swc-win32-ia32-msvc/-/swc-win32-ia32-msvc-14.2.29.tgz", + "integrity": "sha512-vkcriFROT4wsTdSeIzbxaZjTNTFKjSYmLd8q/GVH3Dn8JmYjUKOuKXHK8n+lovW/kdcpIvydO5GtN+It2CvKWA==", "cpu": [ "ia32" ], @@ -482,9 +559,9 @@ } }, "node_modules/@next/swc-win32-x64-msvc": { - "version": "14.2.28", - "resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-14.2.28.tgz", - "integrity": "sha512-1gCmpvyhz7DkB1srRItJTnmR2UwQPAUXXIg9r0/56g3O8etGmwlX68skKXJOp9EejW3hhv7nSQUJ2raFiz4MoA==", + "version": "14.2.29", + "resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-14.2.29.tgz", + "integrity": "sha512-iPPwUEKnVs7pwR0EBLJlwxLD7TTHWS/AoVZx1l9ZQzfQciqaFEr5AlYzA2uB6Fyby1IF18t4PL0nTpB+k4Tzlw==", "cpu": [ "x64" ], @@ -575,12 +652,12 @@ "license": "MIT" }, 
"node_modules/@radix-ui/react-accessible-icon": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/@radix-ui/react-accessible-icon/-/react-accessible-icon-1.1.4.tgz", - "integrity": "sha512-J8pIt7l32A9fGIn86vwccQzik5MgIOTtceeTxi6EiiFYwWHLxsTHwiOW4pI5sQhQJWd3MOEkumFBIHwIU038Cw==", + "version": "1.1.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-accessible-icon/-/react-accessible-icon-1.1.7.tgz", + "integrity": "sha512-XM+E4WXl0OqUJFovy6GjmxxFyx9opfCAIUku4dlKRd5YEPqt4kALOkQOp0Of6reHuUkJuiPBEc5k0o4z4lTC8A==", "license": "MIT", "dependencies": { - "@radix-ui/react-visually-hidden": "1.2.0" + "@radix-ui/react-visually-hidden": "1.2.3" }, "peerDependencies": { "@types/react": "*", @@ -598,19 +675,19 @@ } }, "node_modules/@radix-ui/react-accordion": { - "version": "1.2.8", - "resolved": "https://registry.npmjs.org/@radix-ui/react-accordion/-/react-accordion-1.2.8.tgz", - "integrity": "sha512-c7OKBvO36PfQIUGIjj1Wko0hH937pYFU2tR5zbIJDUsmTzHoZVHHt4bmb7OOJbzTaWJtVELKWojBHa7OcnUHmQ==", + "version": "1.2.11", + "resolved": "https://registry.npmjs.org/@radix-ui/react-accordion/-/react-accordion-1.2.11.tgz", + "integrity": "sha512-l3W5D54emV2ues7jjeG1xcyN7S3jnK3zE2zHqgn0CmMsy9lNJwmgcrmaxS+7ipw15FAivzKNzH3d5EcGoFKw0A==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", - "@radix-ui/react-collapsible": "1.1.8", - "@radix-ui/react-collection": "1.1.4", + "@radix-ui/react-collapsible": "1.1.11", + "@radix-ui/react-collection": "1.1.7", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-direction": "1.1.1", "@radix-ui/react-id": "1.1.1", - "@radix-ui/react-primitive": "2.1.0", + "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-controllable-state": "1.2.2" }, "peerDependencies": { @@ -629,17 +706,17 @@ } }, "node_modules/@radix-ui/react-alert-dialog": { - "version": "1.1.11", - "resolved": "https://registry.npmjs.org/@radix-ui/react-alert-dialog/-/react-alert-dialog-1.1.11.tgz", - 
"integrity": "sha512-4KfkwrFnAw3Y5Jeoq6G+JYSKW0JfIS3uDdFC/79Jw9AsMayZMizSSMxk1gkrolYXsa/WzbbDfOA7/D8N5D+l1g==", + "version": "1.1.14", + "resolved": "https://registry.npmjs.org/@radix-ui/react-alert-dialog/-/react-alert-dialog-1.1.14.tgz", + "integrity": "sha512-IOZfZ3nPvN6lXpJTBCunFQPRSvK8MDgSc1FB85xnIpUKOw9en0dJj8JmCAxV7BiZdtYlUpmrQjoTFkVYtdoWzQ==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", - "@radix-ui/react-dialog": "1.1.11", - "@radix-ui/react-primitive": "2.1.0", - "@radix-ui/react-slot": "1.2.0" + "@radix-ui/react-dialog": "1.1.14", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-slot": "1.2.3" }, "peerDependencies": { "@types/react": "*", @@ -657,12 +734,12 @@ } }, "node_modules/@radix-ui/react-arrow": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/@radix-ui/react-arrow/-/react-arrow-1.1.4.tgz", - "integrity": "sha512-qz+fxrqgNxG0dYew5l7qR3c7wdgRu1XVUHGnGYX7rg5HM4p9SWaRmJwfgR3J0SgyUKayLmzQIun+N6rWRgiRKw==", + "version": "1.1.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-arrow/-/react-arrow-1.1.7.tgz", + "integrity": "sha512-F+M1tLhO+mlQaOWspE8Wstg+z6PwxwRd8oQ8IXceWz92kfAmalTRf0EjrouQeo7QssEPfCn05B4Ihs1K9WQ/7w==", "license": "MIT", "dependencies": { - "@radix-ui/react-primitive": "2.1.0" + "@radix-ui/react-primitive": "2.1.3" }, "peerDependencies": { "@types/react": "*", @@ -680,12 +757,12 @@ } }, "node_modules/@radix-ui/react-aspect-ratio": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/@radix-ui/react-aspect-ratio/-/react-aspect-ratio-1.1.4.tgz", - "integrity": "sha512-ie2mUDtM38LBqVU+Xn+GIY44tWM5yVbT5uXO+th85WZxUUsgEdWNNZWecqqGzkQ4Af+Fq1mYT6TyQ/uUf5gfcw==", + "version": "1.1.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-aspect-ratio/-/react-aspect-ratio-1.1.7.tgz", + "integrity": 
"sha512-Yq6lvO9HQyPwev1onK1daHCHqXVLzPhSVjmsNjCa2Zcxy2f7uJD2itDtxknv6FzAKCwD1qQkeVDmX/cev13n/g==", "license": "MIT", "dependencies": { - "@radix-ui/react-primitive": "2.1.0" + "@radix-ui/react-primitive": "2.1.3" }, "peerDependencies": { "@types/react": "*", @@ -703,13 +780,13 @@ } }, "node_modules/@radix-ui/react-avatar": { - "version": "1.1.7", - "resolved": "https://registry.npmjs.org/@radix-ui/react-avatar/-/react-avatar-1.1.7.tgz", - "integrity": "sha512-V7ODUt4mUoJTe3VUxZw6nfURxaPALVqmDQh501YmaQsk3D8AZQrOPRnfKn4H7JGDLBc0KqLhT94H79nV88ppNg==", + "version": "1.1.10", + "resolved": "https://registry.npmjs.org/@radix-ui/react-avatar/-/react-avatar-1.1.10.tgz", + "integrity": "sha512-V8piFfWapM5OmNCXTzVQY+E1rDa53zY+MQ4Y7356v4fFz6vqCyUtIz2rUD44ZEdwg78/jKmMJHj07+C/Z/rcog==", "license": "MIT", "dependencies": { "@radix-ui/react-context": "1.1.2", - "@radix-ui/react-primitive": "2.1.0", + "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-callback-ref": "1.1.1", "@radix-ui/react-use-is-hydrated": "0.1.0", "@radix-ui/react-use-layout-effect": "1.1.1" @@ -730,16 +807,16 @@ } }, "node_modules/@radix-ui/react-checkbox": { - "version": "1.2.3", - "resolved": "https://registry.npmjs.org/@radix-ui/react-checkbox/-/react-checkbox-1.2.3.tgz", - "integrity": "sha512-pHVzDYsnaDmBlAuwim45y3soIN8H4R7KbkSVirGhXO+R/kO2OLCe0eucUEbddaTcdMHHdzcIGHtZSMSQlA+apw==", + "version": "1.3.2", + "resolved": "https://registry.npmjs.org/@radix-ui/react-checkbox/-/react-checkbox-1.3.2.tgz", + "integrity": "sha512-yd+dI56KZqawxKZrJ31eENUwqc1QSqg4OZ15rybGjF2ZNwMO+wCyHzAVLRp9qoYJf7kYy0YpZ2b0JCzJ42HZpA==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-presence": "1.1.4", - "@radix-ui/react-primitive": "2.1.0", + "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-controllable-state": "1.2.2", "@radix-ui/react-use-previous": "1.1.1", "@radix-ui/react-use-size": "1.1.1" @@ 
-760,9 +837,9 @@ } }, "node_modules/@radix-ui/react-collapsible": { - "version": "1.1.8", - "resolved": "https://registry.npmjs.org/@radix-ui/react-collapsible/-/react-collapsible-1.1.8.tgz", - "integrity": "sha512-hxEsLvK9WxIAPyxdDRULL4hcaSjMZCfP7fHB0Z1uUnDoDBat1Zh46hwYfa69DeZAbJrPckjf0AGAtEZyvDyJbw==", + "version": "1.1.11", + "resolved": "https://registry.npmjs.org/@radix-ui/react-collapsible/-/react-collapsible-1.1.11.tgz", + "integrity": "sha512-2qrRsVGSCYasSz1RFOorXwl0H7g7J1frQtgpQgYrt+MOidtPAINHn9CPovQXb83r8ahapdx3Tu0fa/pdFFSdPg==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", @@ -770,7 +847,7 @@ "@radix-ui/react-context": "1.1.2", "@radix-ui/react-id": "1.1.1", "@radix-ui/react-presence": "1.1.4", - "@radix-ui/react-primitive": "2.1.0", + "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-controllable-state": "1.2.2", "@radix-ui/react-use-layout-effect": "1.1.1" }, @@ -790,15 +867,15 @@ } }, "node_modules/@radix-ui/react-collection": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/@radix-ui/react-collection/-/react-collection-1.1.4.tgz", - "integrity": "sha512-cv4vSf7HttqXilDnAnvINd53OTl1/bjUYVZrkFnA7nwmY9Ob2POUy0WY0sfqBAe1s5FyKsyceQlqiEGPYNTadg==", + "version": "1.1.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-collection/-/react-collection-1.1.7.tgz", + "integrity": "sha512-Fh9rGN0MoI4ZFUNyfFVNU4y9LUz93u9/0K+yLgA2bwRojxM8JU1DyvvMBabnZPBgMWREAJvU2jjVzq+LrFUglw==", "license": "MIT", "dependencies": { "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", - "@radix-ui/react-primitive": "2.1.0", - "@radix-ui/react-slot": "1.2.0" + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-slot": "1.2.3" }, "peerDependencies": { "@types/react": "*", @@ -846,15 +923,15 @@ } }, "node_modules/@radix-ui/react-context-menu": { - "version": "2.2.12", - "resolved": "https://registry.npmjs.org/@radix-ui/react-context-menu/-/react-context-menu-2.2.12.tgz", - "integrity": 
"sha512-5UFKuTMX8F2/KjHvyqu9IYT8bEtDSCJwwIx1PghBo4jh9S6jJVsceq9xIjqsOVcxsynGwV5eaqPE3n/Cu+DrSA==", + "version": "2.2.15", + "resolved": "https://registry.npmjs.org/@radix-ui/react-context-menu/-/react-context-menu-2.2.15.tgz", + "integrity": "sha512-UsQUMjcYTsBjTSXw0P3GO0werEQvUY2plgRQuKoCTtkNr45q1DiL51j4m7gxhABzZ0BadoXNsIbg7F3KwiUBbw==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-context": "1.1.2", - "@radix-ui/react-menu": "2.1.12", - "@radix-ui/react-primitive": "2.1.0", + "@radix-ui/react-menu": "2.1.15", + "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-callback-ref": "1.1.1", "@radix-ui/react-use-controllable-state": "1.2.2" }, @@ -874,22 +951,22 @@ } }, "node_modules/@radix-ui/react-dialog": { - "version": "1.1.11", - "resolved": "https://registry.npmjs.org/@radix-ui/react-dialog/-/react-dialog-1.1.11.tgz", - "integrity": "sha512-yI7S1ipkP5/+99qhSI6nthfo/tR6bL6Zgxi/+1UO6qPa6UeM6nlafWcQ65vB4rU2XjgjMfMhI3k9Y5MztA62VQ==", + "version": "1.1.14", + "resolved": "https://registry.npmjs.org/@radix-ui/react-dialog/-/react-dialog-1.1.14.tgz", + "integrity": "sha512-+CpweKjqpzTmwRwcYECQcNYbI8V9VSQt0SNFKeEBLgfucbsLssU6Ppq7wUdNXEGb573bMjFhVjKVll8rmV6zMw==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", - "@radix-ui/react-dismissable-layer": "1.1.7", + "@radix-ui/react-dismissable-layer": "1.1.10", "@radix-ui/react-focus-guards": "1.1.2", - "@radix-ui/react-focus-scope": "1.1.4", + "@radix-ui/react-focus-scope": "1.1.7", "@radix-ui/react-id": "1.1.1", - "@radix-ui/react-portal": "1.1.6", + "@radix-ui/react-portal": "1.1.9", "@radix-ui/react-presence": "1.1.4", - "@radix-ui/react-primitive": "2.1.0", - "@radix-ui/react-slot": "1.2.0", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-slot": "1.2.3", "@radix-ui/react-use-controllable-state": "1.2.2", "aria-hidden": "^1.2.4", "react-remove-scroll": "^2.6.3" @@ 
-925,14 +1002,14 @@ } }, "node_modules/@radix-ui/react-dismissable-layer": { - "version": "1.1.7", - "resolved": "https://registry.npmjs.org/@radix-ui/react-dismissable-layer/-/react-dismissable-layer-1.1.7.tgz", - "integrity": "sha512-j5+WBUdhccJsmH5/H0K6RncjDtoALSEr6jbkaZu+bjw6hOPOhHycr6vEUujl+HBK8kjUfWcoCJXxP6e4lUlMZw==", + "version": "1.1.10", + "resolved": "https://registry.npmjs.org/@radix-ui/react-dismissable-layer/-/react-dismissable-layer-1.1.10.tgz", + "integrity": "sha512-IM1zzRV4W3HtVgftdQiiOmA0AdJlCtMLe00FXaHwgt3rAnNsIyDqshvkIW3hj/iu5hu8ERP7KIYki6NkqDxAwQ==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-compose-refs": "1.1.2", - "@radix-ui/react-primitive": "2.1.0", + "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-callback-ref": "1.1.1", "@radix-ui/react-use-escape-keydown": "1.1.1" }, @@ -952,17 +1029,17 @@ } }, "node_modules/@radix-ui/react-dropdown-menu": { - "version": "2.1.12", - "resolved": "https://registry.npmjs.org/@radix-ui/react-dropdown-menu/-/react-dropdown-menu-2.1.12.tgz", - "integrity": "sha512-VJoMs+BWWE7YhzEQyVwvF9n22Eiyr83HotCVrMQzla/OwRovXCgah7AcaEr4hMNj4gJxSdtIbcHGvmJXOoJVHA==", + "version": "2.1.15", + "resolved": "https://registry.npmjs.org/@radix-ui/react-dropdown-menu/-/react-dropdown-menu-2.1.15.tgz", + "integrity": "sha512-mIBnOjgwo9AH3FyKaSWoSu/dYj6VdhJ7frEPiGTeXCdUFHjl9h3mFh2wwhEtINOmYXWhdpf1rY2minFsmaNgVQ==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-id": "1.1.1", - "@radix-ui/react-menu": "2.1.12", - "@radix-ui/react-primitive": "2.1.0", + "@radix-ui/react-menu": "2.1.15", + "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-controllable-state": "1.2.2" }, "peerDependencies": { @@ -996,13 +1073,13 @@ } }, "node_modules/@radix-ui/react-focus-scope": { - "version": "1.1.4", - "resolved": 
"https://registry.npmjs.org/@radix-ui/react-focus-scope/-/react-focus-scope-1.1.4.tgz", - "integrity": "sha512-r2annK27lIW5w9Ho5NyQgqs0MmgZSTIKXWpVCJaLC1q2kZrZkcqnmHkCHMEmv8XLvsLlurKMPT+kbKkRkm/xVA==", + "version": "1.1.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-focus-scope/-/react-focus-scope-1.1.7.tgz", + "integrity": "sha512-t2ODlkXBQyn7jkl6TNaw/MtVEVvIGelJDCG41Okq/KwUsJBwQ4XVZsHAVUkK4mBv3ewiAS3PGuUWuY2BoK4ZUw==", "license": "MIT", "dependencies": { "@radix-ui/react-compose-refs": "1.1.2", - "@radix-ui/react-primitive": "2.1.0", + "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-callback-ref": "1.1.1" }, "peerDependencies": { @@ -1021,17 +1098,17 @@ } }, "node_modules/@radix-ui/react-form": { - "version": "0.1.4", - "resolved": "https://registry.npmjs.org/@radix-ui/react-form/-/react-form-0.1.4.tgz", - "integrity": "sha512-97Q7Hb0///sMF2X8XvyVx3Aub7WG/ybIofoDVUo8utG/z/6TBzWGjgai7ZjECXYLbKip88t9/ibyQJvYe5k6SA==", + "version": "0.1.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-form/-/react-form-0.1.7.tgz", + "integrity": "sha512-IXLKFnaYvFg/KkeV5QfOX7tRnwHXp127koOFUjLWMTrRv5Rny3DQcAtIFFeA/Cli4HHM8DuJCXAUsgnFVJndlw==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-id": "1.1.1", - "@radix-ui/react-label": "2.1.4", - "@radix-ui/react-primitive": "2.1.0" + "@radix-ui/react-label": "2.1.7", + "@radix-ui/react-primitive": "2.1.3" }, "peerDependencies": { "@types/react": "*", @@ -1049,19 +1126,19 @@ } }, "node_modules/@radix-ui/react-hover-card": { - "version": "1.1.11", - "resolved": "https://registry.npmjs.org/@radix-ui/react-hover-card/-/react-hover-card-1.1.11.tgz", - "integrity": "sha512-q9h9grUpGZKR3MNhtVCLVnPGmx1YnzBgGR+O40mhSNGsUnkR+LChVH8c7FB0mkS+oudhd8KAkZGTJPJCjdAPIg==", + "version": "1.1.14", + "resolved": "https://registry.npmjs.org/@radix-ui/react-hover-card/-/react-hover-card-1.1.14.tgz", + 
"integrity": "sha512-CPYZ24Mhirm+g6D8jArmLzjYu4Eyg3TTUHswR26QgzXBHBe64BO/RHOJKzmF/Dxb4y4f9PKyJdwm/O/AhNkb+Q==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", - "@radix-ui/react-dismissable-layer": "1.1.7", - "@radix-ui/react-popper": "1.2.4", - "@radix-ui/react-portal": "1.1.6", + "@radix-ui/react-dismissable-layer": "1.1.10", + "@radix-ui/react-popper": "1.2.7", + "@radix-ui/react-portal": "1.1.9", "@radix-ui/react-presence": "1.1.4", - "@radix-ui/react-primitive": "2.1.0", + "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-controllable-state": "1.2.2" }, "peerDependencies": { @@ -1107,12 +1184,12 @@ } }, "node_modules/@radix-ui/react-label": { - "version": "2.1.4", - "resolved": "https://registry.npmjs.org/@radix-ui/react-label/-/react-label-2.1.4.tgz", - "integrity": "sha512-wy3dqizZnZVV4ja0FNnUhIWNwWdoldXrneEyUcVtLYDAt8ovGS4ridtMAOGgXBBIfggL4BOveVWsjXDORdGEQg==", + "version": "2.1.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-label/-/react-label-2.1.7.tgz", + "integrity": "sha512-YT1GqPSL8kJn20djelMX7/cTRp/Y9w5IZHvfxQTVHrOqa2yMl7i/UfMqKRU5V7mEyKTrUVgJXhNQPVCG8PBLoQ==", "license": "MIT", "dependencies": { - "@radix-ui/react-primitive": "2.1.0" + "@radix-ui/react-primitive": "2.1.3" }, "peerDependencies": { "@types/react": "*", @@ -1130,26 +1207,26 @@ } }, "node_modules/@radix-ui/react-menu": { - "version": "2.1.12", - "resolved": "https://registry.npmjs.org/@radix-ui/react-menu/-/react-menu-2.1.12.tgz", - "integrity": "sha512-+qYq6LfbiGo97Zz9fioX83HCiIYYFNs8zAsVCMQrIakoNYylIzWuoD/anAD3UzvvR6cnswmfRFJFq/zYYq/k7Q==", + "version": "2.1.15", + "resolved": "https://registry.npmjs.org/@radix-ui/react-menu/-/react-menu-2.1.15.tgz", + "integrity": "sha512-tVlmA3Vb9n8SZSd+YSbuFR66l87Wiy4du+YE+0hzKQEANA+7cWKH1WgqcEX4pXqxUFQKrWQGHdvEfw00TjFiew==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", - "@radix-ui/react-collection": 
"1.1.4", + "@radix-ui/react-collection": "1.1.7", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-direction": "1.1.1", - "@radix-ui/react-dismissable-layer": "1.1.7", + "@radix-ui/react-dismissable-layer": "1.1.10", "@radix-ui/react-focus-guards": "1.1.2", - "@radix-ui/react-focus-scope": "1.1.4", + "@radix-ui/react-focus-scope": "1.1.7", "@radix-ui/react-id": "1.1.1", - "@radix-ui/react-popper": "1.2.4", - "@radix-ui/react-portal": "1.1.6", + "@radix-ui/react-popper": "1.2.7", + "@radix-ui/react-portal": "1.1.9", "@radix-ui/react-presence": "1.1.4", - "@radix-ui/react-primitive": "2.1.0", - "@radix-ui/react-roving-focus": "1.1.7", - "@radix-ui/react-slot": "1.2.0", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-roving-focus": "1.1.10", + "@radix-ui/react-slot": "1.2.3", "@radix-ui/react-use-callback-ref": "1.1.1", "aria-hidden": "^1.2.4", "react-remove-scroll": "^2.6.3" @@ -1170,20 +1247,20 @@ } }, "node_modules/@radix-ui/react-menubar": { - "version": "1.1.12", - "resolved": "https://registry.npmjs.org/@radix-ui/react-menubar/-/react-menubar-1.1.12.tgz", - "integrity": "sha512-bM2vT5nxRqJH/d1vFQ9jLsW4qR70yFQw2ZD1TUPWUNskDsV0eYeMbbNJqxNjGMOVogEkOJaHtu11kzYdTJvVJg==", + "version": "1.1.15", + "resolved": "https://registry.npmjs.org/@radix-ui/react-menubar/-/react-menubar-1.1.15.tgz", + "integrity": "sha512-Z71C7LGD+YDYo3TV81paUs8f3Zbmkvg6VLRQpKYfzioOE6n7fOhA3ApK/V/2Odolxjoc4ENk8AYCjohCNayd5A==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", - "@radix-ui/react-collection": "1.1.4", + "@radix-ui/react-collection": "1.1.7", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-direction": "1.1.1", "@radix-ui/react-id": "1.1.1", - "@radix-ui/react-menu": "2.1.12", - "@radix-ui/react-primitive": "2.1.0", - "@radix-ui/react-roving-focus": "1.1.7", + "@radix-ui/react-menu": "2.1.15", + "@radix-ui/react-primitive": "2.1.3", + 
"@radix-ui/react-roving-focus": "1.1.10", "@radix-ui/react-use-controllable-state": "1.2.2" }, "peerDependencies": { @@ -1202,25 +1279,25 @@ } }, "node_modules/@radix-ui/react-navigation-menu": { - "version": "1.2.10", - "resolved": "https://registry.npmjs.org/@radix-ui/react-navigation-menu/-/react-navigation-menu-1.2.10.tgz", - "integrity": "sha512-kGDqMVPj2SRB1vJmXN/jnhC66REAXNyDmDRubbbmJ+360zSIJUDmWGMKIJOf72PHMwPENrbtJVb3CMAUJDjEIA==", + "version": "1.2.13", + "resolved": "https://registry.npmjs.org/@radix-ui/react-navigation-menu/-/react-navigation-menu-1.2.13.tgz", + "integrity": "sha512-WG8wWfDiJlSF5hELjwfjSGOXcBR/ZMhBFCGYe8vERpC39CQYZeq1PQ2kaYHdye3V95d06H89KGMsVCIE4LWo3g==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", - "@radix-ui/react-collection": "1.1.4", + "@radix-ui/react-collection": "1.1.7", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-direction": "1.1.1", - "@radix-ui/react-dismissable-layer": "1.1.7", + "@radix-ui/react-dismissable-layer": "1.1.10", "@radix-ui/react-id": "1.1.1", "@radix-ui/react-presence": "1.1.4", - "@radix-ui/react-primitive": "2.1.0", + "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-callback-ref": "1.1.1", "@radix-ui/react-use-controllable-state": "1.2.2", "@radix-ui/react-use-layout-effect": "1.1.1", "@radix-ui/react-use-previous": "1.1.1", - "@radix-ui/react-visually-hidden": "1.2.0" + "@radix-ui/react-visually-hidden": "1.2.3" }, "peerDependencies": { "@types/react": "*", @@ -1238,19 +1315,19 @@ } }, "node_modules/@radix-ui/react-one-time-password-field": { - "version": "0.1.4", - "resolved": "https://registry.npmjs.org/@radix-ui/react-one-time-password-field/-/react-one-time-password-field-0.1.4.tgz", - "integrity": "sha512-CygYLHY8kO1De5iAZBn7gQbIoRNVGYx1paIyqbmwlxP6DF7sF1LLW3chXo/qxc4IWUQnsgAhfl9u6IoLXTndqQ==", + "version": "0.1.7", + "resolved": 
"https://registry.npmjs.org/@radix-ui/react-one-time-password-field/-/react-one-time-password-field-0.1.7.tgz", + "integrity": "sha512-w1vm7AGI8tNXVovOK7TYQHrAGpRF7qQL+ENpT1a743De5Zmay2RbWGKAiYDKIyIuqptns+znCKwNztE2xl1n0Q==", "license": "MIT", "dependencies": { "@radix-ui/number": "1.1.1", "@radix-ui/primitive": "1.1.2", - "@radix-ui/react-collection": "1.1.4", + "@radix-ui/react-collection": "1.1.7", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-direction": "1.1.1", - "@radix-ui/react-primitive": "2.1.0", - "@radix-ui/react-roving-focus": "1.1.7", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-roving-focus": "1.1.10", "@radix-ui/react-use-controllable-state": "1.2.2", "@radix-ui/react-use-effect-event": "0.0.2", "@radix-ui/react-use-is-hydrated": "0.1.0", @@ -1271,24 +1348,54 @@ } } }, + "node_modules/@radix-ui/react-password-toggle-field": { + "version": "0.1.2", + "resolved": "https://registry.npmjs.org/@radix-ui/react-password-toggle-field/-/react-password-toggle-field-0.1.2.tgz", + "integrity": "sha512-F90uYnlBsLPU1UbSLciLsWQmk8+hdWa6SFw4GXaIdNWxFxI5ITKVdAG64f+Twaa9ic6xE7pqxPyUmodrGjT4pQ==", + "license": "MIT", + "dependencies": { + "@radix-ui/primitive": "1.1.2", + "@radix-ui/react-compose-refs": "1.1.2", + "@radix-ui/react-context": "1.1.2", + "@radix-ui/react-id": "1.1.1", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-use-controllable-state": "1.2.2", + "@radix-ui/react-use-effect-event": "0.0.2", + "@radix-ui/react-use-is-hydrated": "0.1.0" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, "node_modules/@radix-ui/react-popover": { - "version": "1.1.11", - "resolved": 
"https://registry.npmjs.org/@radix-ui/react-popover/-/react-popover-1.1.11.tgz", - "integrity": "sha512-yFMfZkVA5G3GJnBgb2PxrrcLKm1ZLWXrbYVgdyTl//0TYEIHS9LJbnyz7WWcZ0qCq7hIlJZpRtxeSeIG5T5oJw==", + "version": "1.1.14", + "resolved": "https://registry.npmjs.org/@radix-ui/react-popover/-/react-popover-1.1.14.tgz", + "integrity": "sha512-ODz16+1iIbGUfFEfKx2HTPKizg2MN39uIOV8MXeHnmdd3i/N9Wt7vU46wbHsqA0xoaQyXVcs0KIlBdOA2Y95bw==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", - "@radix-ui/react-dismissable-layer": "1.1.7", + "@radix-ui/react-dismissable-layer": "1.1.10", "@radix-ui/react-focus-guards": "1.1.2", - "@radix-ui/react-focus-scope": "1.1.4", + "@radix-ui/react-focus-scope": "1.1.7", "@radix-ui/react-id": "1.1.1", - "@radix-ui/react-popper": "1.2.4", - "@radix-ui/react-portal": "1.1.6", + "@radix-ui/react-popper": "1.2.7", + "@radix-ui/react-portal": "1.1.9", "@radix-ui/react-presence": "1.1.4", - "@radix-ui/react-primitive": "2.1.0", - "@radix-ui/react-slot": "1.2.0", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-slot": "1.2.3", "@radix-ui/react-use-controllable-state": "1.2.2", "aria-hidden": "^1.2.4", "react-remove-scroll": "^2.6.3" @@ -1309,16 +1416,16 @@ } }, "node_modules/@radix-ui/react-popper": { - "version": "1.2.4", - "resolved": "https://registry.npmjs.org/@radix-ui/react-popper/-/react-popper-1.2.4.tgz", - "integrity": "sha512-3p2Rgm/a1cK0r/UVkx5F/K9v/EplfjAeIFCGOPYPO4lZ0jtg4iSQXt/YGTSLWaf4x7NG6Z4+uKFcylcTZjeqDA==", + "version": "1.2.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-popper/-/react-popper-1.2.7.tgz", + "integrity": "sha512-IUFAccz1JyKcf/RjB552PlWwxjeCJB8/4KxT7EhBHOJM+mN7LdW+B3kacJXILm32xawcMMjb2i0cIZpo+f9kiQ==", "license": "MIT", "dependencies": { "@floating-ui/react-dom": "^2.0.0", - "@radix-ui/react-arrow": "1.1.4", + "@radix-ui/react-arrow": "1.1.7", "@radix-ui/react-compose-refs": "1.1.2", 
"@radix-ui/react-context": "1.1.2", - "@radix-ui/react-primitive": "2.1.0", + "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-callback-ref": "1.1.1", "@radix-ui/react-use-layout-effect": "1.1.1", "@radix-ui/react-use-rect": "1.1.1", @@ -1341,12 +1448,12 @@ } }, "node_modules/@radix-ui/react-portal": { - "version": "1.1.6", - "resolved": "https://registry.npmjs.org/@radix-ui/react-portal/-/react-portal-1.1.6.tgz", - "integrity": "sha512-XmsIl2z1n/TsYFLIdYam2rmFwf9OC/Sh2avkbmVMDuBZIe7hSpM0cYnWPAo7nHOVx8zTuwDZGByfcqLdnzp3Vw==", + "version": "1.1.9", + "resolved": "https://registry.npmjs.org/@radix-ui/react-portal/-/react-portal-1.1.9.tgz", + "integrity": "sha512-bpIxvq03if6UNwXZ+HTK71JLh4APvnXntDc6XOX8UVq4XQOVl7lwok0AvIl+b8zgCw3fSaVTZMpAPPagXbKmHQ==", "license": "MIT", "dependencies": { - "@radix-ui/react-primitive": "2.1.0", + "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-layout-effect": "1.1.1" }, "peerDependencies": { @@ -1389,12 +1496,12 @@ } }, "node_modules/@radix-ui/react-primitive": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/@radix-ui/react-primitive/-/react-primitive-2.1.0.tgz", - "integrity": "sha512-/J/FhLdK0zVcILOwt5g+dH4KnkonCtkVJsa2G6JmvbbtZfBEI1gMsO3QMjseL4F/SwfAMt1Vc/0XKYKq+xJ1sw==", + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-primitive/-/react-primitive-2.1.3.tgz", + "integrity": "sha512-m9gTwRkhy2lvCPe6QJp4d3G1TYEUHn/FzJUtq9MjH46an1wJU+GdoGC5VLof8RX8Ft/DlpshApkhswDLZzHIcQ==", "license": "MIT", "dependencies": { - "@radix-ui/react-slot": "1.2.0" + "@radix-ui/react-slot": "1.2.3" }, "peerDependencies": { "@types/react": "*", @@ -1412,13 +1519,13 @@ } }, "node_modules/@radix-ui/react-progress": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/@radix-ui/react-progress/-/react-progress-1.1.4.tgz", - "integrity": "sha512-8rl9w7lJdcVPor47Dhws9mUHRHLE+8JEgyJRdNWCpGPa6HIlr3eh+Yn9gyx1CnCLbw5naHsI2gaO9dBWO50vzw==", + "version": "1.1.7", + "resolved": 
"https://registry.npmjs.org/@radix-ui/react-progress/-/react-progress-1.1.7.tgz", + "integrity": "sha512-vPdg/tF6YC/ynuBIJlk1mm7Le0VgW6ub6J2UWnTQ7/D23KXcPI1qy+0vBkgKgd38RCMJavBXpB83HPNFMTb0Fg==", "license": "MIT", "dependencies": { "@radix-ui/react-context": "1.1.2", - "@radix-ui/react-primitive": "2.1.0" + "@radix-ui/react-primitive": "2.1.3" }, "peerDependencies": { "@types/react": "*", @@ -1436,9 +1543,9 @@ } }, "node_modules/@radix-ui/react-radio-group": { - "version": "1.3.4", - "resolved": "https://registry.npmjs.org/@radix-ui/react-radio-group/-/react-radio-group-1.3.4.tgz", - "integrity": "sha512-N4J9QFdW5zcJNxxY/zwTXBN4Uc5VEuRM7ZLjNfnWoKmNvgrPtNNw4P8zY532O3qL6aPkaNO+gY9y6bfzmH4U1g==", + "version": "1.3.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-radio-group/-/react-radio-group-1.3.7.tgz", + "integrity": "sha512-9w5XhD0KPOrm92OTTE0SysH3sYzHsSTHNvZgUBo/VZ80VdYyB5RneDbc0dKpURS24IxkoFRu/hI0i4XyfFwY6g==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", @@ -1446,8 +1553,8 @@ "@radix-ui/react-context": "1.1.2", "@radix-ui/react-direction": "1.1.1", "@radix-ui/react-presence": "1.1.4", - "@radix-ui/react-primitive": "2.1.0", - "@radix-ui/react-roving-focus": "1.1.7", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-roving-focus": "1.1.10", "@radix-ui/react-use-controllable-state": "1.2.2", "@radix-ui/react-use-previous": "1.1.1", "@radix-ui/react-use-size": "1.1.1" @@ -1468,18 +1575,18 @@ } }, "node_modules/@radix-ui/react-roving-focus": { - "version": "1.1.7", - "resolved": "https://registry.npmjs.org/@radix-ui/react-roving-focus/-/react-roving-focus-1.1.7.tgz", - "integrity": "sha512-C6oAg451/fQT3EGbWHbCQjYTtbyjNO1uzQgMzwyivcHT3GKNEmu1q3UuREhN+HzHAVtv3ivMVK08QlC+PkYw9Q==", + "version": "1.1.10", + "resolved": "https://registry.npmjs.org/@radix-ui/react-roving-focus/-/react-roving-focus-1.1.10.tgz", + "integrity": "sha512-dT9aOXUen9JSsxnMPv/0VqySQf5eDQ6LCk5Sw28kamz8wSOW2bJdlX2Bg5VUIIcV+6XlHpWTIuTPCf/UNIyq8Q==", 
"license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", - "@radix-ui/react-collection": "1.1.4", + "@radix-ui/react-collection": "1.1.7", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-direction": "1.1.1", "@radix-ui/react-id": "1.1.1", - "@radix-ui/react-primitive": "2.1.0", + "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-callback-ref": "1.1.1", "@radix-ui/react-use-controllable-state": "1.2.2" }, @@ -1499,9 +1606,9 @@ } }, "node_modules/@radix-ui/react-scroll-area": { - "version": "1.2.6", - "resolved": "https://registry.npmjs.org/@radix-ui/react-scroll-area/-/react-scroll-area-1.2.6.tgz", - "integrity": "sha512-lj8OMlpPERXrQIHlEQdlXHJoRT52AMpBrgyPYylOhXYq5e/glsEdtOc/kCQlsTdtgN5U0iDbrrolDadvektJGQ==", + "version": "1.2.9", + "resolved": "https://registry.npmjs.org/@radix-ui/react-scroll-area/-/react-scroll-area-1.2.9.tgz", + "integrity": "sha512-YSjEfBXnhUELsO2VzjdtYYD4CfQjvao+lhhrX5XsHD7/cyUNzljF1FHEbgTPN7LH2MClfwRMIsYlqTYpKTTe2A==", "license": "MIT", "dependencies": { "@radix-ui/number": "1.1.1", @@ -1510,7 +1617,7 @@ "@radix-ui/react-context": "1.1.2", "@radix-ui/react-direction": "1.1.1", "@radix-ui/react-presence": "1.1.4", - "@radix-ui/react-primitive": "2.1.0", + "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-callback-ref": "1.1.1", "@radix-ui/react-use-layout-effect": "1.1.1" }, @@ -1530,30 +1637,30 @@ } }, "node_modules/@radix-ui/react-select": { - "version": "2.2.2", - "resolved": "https://registry.npmjs.org/@radix-ui/react-select/-/react-select-2.2.2.tgz", - "integrity": "sha512-HjkVHtBkuq+r3zUAZ/CvNWUGKPfuicGDbgtZgiQuFmNcV5F+Tgy24ep2nsAW2nFgvhGPJVqeBZa6KyVN0EyrBA==", + "version": "2.2.5", + "resolved": "https://registry.npmjs.org/@radix-ui/react-select/-/react-select-2.2.5.tgz", + "integrity": "sha512-HnMTdXEVuuyzx63ME0ut4+sEMYW6oouHWNGUZc7ddvUWIcfCva/AMoqEW/3wnEllriMWBa0RHspCYnfCWJQYmA==", "license": "MIT", "dependencies": { "@radix-ui/number": "1.1.1", 
"@radix-ui/primitive": "1.1.2", - "@radix-ui/react-collection": "1.1.4", + "@radix-ui/react-collection": "1.1.7", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-direction": "1.1.1", - "@radix-ui/react-dismissable-layer": "1.1.7", + "@radix-ui/react-dismissable-layer": "1.1.10", "@radix-ui/react-focus-guards": "1.1.2", - "@radix-ui/react-focus-scope": "1.1.4", + "@radix-ui/react-focus-scope": "1.1.7", "@radix-ui/react-id": "1.1.1", - "@radix-ui/react-popper": "1.2.4", - "@radix-ui/react-portal": "1.1.6", - "@radix-ui/react-primitive": "2.1.0", - "@radix-ui/react-slot": "1.2.0", + "@radix-ui/react-popper": "1.2.7", + "@radix-ui/react-portal": "1.1.9", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-slot": "1.2.3", "@radix-ui/react-use-callback-ref": "1.1.1", "@radix-ui/react-use-controllable-state": "1.2.2", "@radix-ui/react-use-layout-effect": "1.1.1", "@radix-ui/react-use-previous": "1.1.1", - "@radix-ui/react-visually-hidden": "1.2.0", + "@radix-ui/react-visually-hidden": "1.2.3", "aria-hidden": "^1.2.4", "react-remove-scroll": "^2.6.3" }, @@ -1573,12 +1680,12 @@ } }, "node_modules/@radix-ui/react-separator": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/@radix-ui/react-separator/-/react-separator-1.1.4.tgz", - "integrity": "sha512-2fTm6PSiUm8YPq9W0E4reYuv01EE3aFSzt8edBiXqPHshF8N9+Kymt/k0/R+F3dkY5lQyB/zPtrP82phskLi7w==", + "version": "1.1.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-separator/-/react-separator-1.1.7.tgz", + "integrity": "sha512-0HEb8R9E8A+jZjvmFCy/J4xhbXy3TV+9XSnGJ3KvTtjlIUy/YQ/p6UYZvi7YbeoeXdyU9+Y3scizK6hkY37baA==", "license": "MIT", "dependencies": { - "@radix-ui/react-primitive": "2.1.0" + "@radix-ui/react-primitive": "2.1.3" }, "peerDependencies": { "@types/react": "*", @@ -1596,18 +1703,18 @@ } }, "node_modules/@radix-ui/react-slider": { - "version": "1.3.2", - "resolved": "https://registry.npmjs.org/@radix-ui/react-slider/-/react-slider-1.3.2.tgz", 
- "integrity": "sha512-oQnqfgSiYkxZ1MrF6672jw2/zZvpB+PJsrIc3Zm1zof1JHf/kj7WhmROw7JahLfOwYQ5/+Ip0rFORgF1tjSiaQ==", + "version": "1.3.5", + "resolved": "https://registry.npmjs.org/@radix-ui/react-slider/-/react-slider-1.3.5.tgz", + "integrity": "sha512-rkfe2pU2NBAYfGaxa3Mqosi7VZEWX5CxKaanRv0vZd4Zhl9fvQrg0VM93dv3xGLGfrHuoTRF3JXH8nb9g+B3fw==", "license": "MIT", "dependencies": { "@radix-ui/number": "1.1.1", "@radix-ui/primitive": "1.1.2", - "@radix-ui/react-collection": "1.1.4", + "@radix-ui/react-collection": "1.1.7", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-direction": "1.1.1", - "@radix-ui/react-primitive": "2.1.0", + "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-controllable-state": "1.2.2", "@radix-ui/react-use-layout-effect": "1.1.1", "@radix-ui/react-use-previous": "1.1.1", @@ -1629,9 +1736,9 @@ } }, "node_modules/@radix-ui/react-slot": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.2.0.tgz", - "integrity": "sha512-ujc+V6r0HNDviYqIK3rW4ffgYiZ8g5DEHrGJVk4x7kTlLXRDILnKX9vAUYeIsLOoDpDJ0ujpqMkjH4w2ofuo6w==", + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.2.3.tgz", + "integrity": "sha512-aeNmHnBxbi2St0au6VBVC7JXFlhLlOnvIIlePNniyUNAClzmtAUEY8/pBiK3iHjufOlwA+c20/8jngo7xcrg8A==", "license": "MIT", "dependencies": { "@radix-ui/react-compose-refs": "1.1.2" @@ -1647,15 +1754,15 @@ } }, "node_modules/@radix-ui/react-switch": { - "version": "1.2.2", - "resolved": "https://registry.npmjs.org/@radix-ui/react-switch/-/react-switch-1.2.2.tgz", - "integrity": "sha512-7Z8n6L+ifMIIYZ83f28qWSceUpkXuslI2FJ34+kDMTiyj91ENdpdQ7VCidrzj5JfwfZTeano/BnGBbu/jqa5rQ==", + "version": "1.2.5", + "resolved": "https://registry.npmjs.org/@radix-ui/react-switch/-/react-switch-1.2.5.tgz", + "integrity": "sha512-5ijLkak6ZMylXsaImpZ8u4Rlf5grRmoc0p0QeX9VJtlrM4f5m3nCTX8tWga/zOA8PZYIR/t0p2Mnvd7InrJ6yQ==", "license": "MIT", 
"dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", - "@radix-ui/react-primitive": "2.1.0", + "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-controllable-state": "1.2.2", "@radix-ui/react-use-previous": "1.1.1", "@radix-ui/react-use-size": "1.1.1" @@ -1676,9 +1783,9 @@ } }, "node_modules/@radix-ui/react-tabs": { - "version": "1.1.9", - "resolved": "https://registry.npmjs.org/@radix-ui/react-tabs/-/react-tabs-1.1.9.tgz", - "integrity": "sha512-KIjtwciYvquiW/wAFkELZCVnaNLBsYNhTNcvl+zfMAbMhRkcvNuCLXDDd22L0j7tagpzVh/QwbFpwAATg7ILPw==", + "version": "1.1.12", + "resolved": "https://registry.npmjs.org/@radix-ui/react-tabs/-/react-tabs-1.1.12.tgz", + "integrity": "sha512-GTVAlRVrQrSw3cEARM0nAx73ixrWDPNZAruETn3oHCNP6SbZ/hNxdxp+u7VkIEv3/sFoLq1PfcHrl7Pnp0CDpw==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", @@ -1686,8 +1793,8 @@ "@radix-ui/react-direction": "1.1.1", "@radix-ui/react-id": "1.1.1", "@radix-ui/react-presence": "1.1.4", - "@radix-ui/react-primitive": "2.1.0", - "@radix-ui/react-roving-focus": "1.1.7", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-roving-focus": "1.1.10", "@radix-ui/react-use-controllable-state": "1.2.2" }, "peerDependencies": { @@ -1706,23 +1813,23 @@ } }, "node_modules/@radix-ui/react-toast": { - "version": "1.2.11", - "resolved": "https://registry.npmjs.org/@radix-ui/react-toast/-/react-toast-1.2.11.tgz", - "integrity": "sha512-Ed2mlOmT+tktOsu2NZBK1bCSHh/uqULu1vWOkpQTVq53EoOuZUZw7FInQoDB3uil5wZc2oe0XN9a7uVZB7/6AQ==", + "version": "1.2.14", + "resolved": "https://registry.npmjs.org/@radix-ui/react-toast/-/react-toast-1.2.14.tgz", + "integrity": "sha512-nAP5FBxBJGQ/YfUB+r+O6USFVkWq3gAInkxyEnmvEV5jtSbfDhfa4hwX8CraCnbjMLsE7XSf/K75l9xXY7joWg==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", - "@radix-ui/react-collection": "1.1.4", + "@radix-ui/react-collection": "1.1.7", "@radix-ui/react-compose-refs": 
"1.1.2", "@radix-ui/react-context": "1.1.2", - "@radix-ui/react-dismissable-layer": "1.1.7", - "@radix-ui/react-portal": "1.1.6", + "@radix-ui/react-dismissable-layer": "1.1.10", + "@radix-ui/react-portal": "1.1.9", "@radix-ui/react-presence": "1.1.4", - "@radix-ui/react-primitive": "2.1.0", + "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-callback-ref": "1.1.1", "@radix-ui/react-use-controllable-state": "1.2.2", "@radix-ui/react-use-layout-effect": "1.1.1", - "@radix-ui/react-visually-hidden": "1.2.0" + "@radix-ui/react-visually-hidden": "1.2.3" }, "peerDependencies": { "@types/react": "*", @@ -1740,13 +1847,13 @@ } }, "node_modules/@radix-ui/react-toggle": { - "version": "1.1.6", - "resolved": "https://registry.npmjs.org/@radix-ui/react-toggle/-/react-toggle-1.1.6.tgz", - "integrity": "sha512-3SeJxKeO3TO1zVw1Nl++Cp0krYk6zHDHMCUXXVkosIzl6Nxcvb07EerQpyD2wXQSJ5RZajrYAmPaydU8Hk1IyQ==", + "version": "1.1.9", + "resolved": "https://registry.npmjs.org/@radix-ui/react-toggle/-/react-toggle-1.1.9.tgz", + "integrity": "sha512-ZoFkBBz9zv9GWer7wIjvdRxmh2wyc2oKWw6C6CseWd6/yq1DK/l5lJ+wnsmFwJZbBYqr02mrf8A2q/CVCuM3ZA==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", - "@radix-ui/react-primitive": "2.1.0", + "@radix-ui/react-primitive": "2.1.3", "@radix-ui/react-use-controllable-state": "1.2.2" }, "peerDependencies": { @@ -1765,17 +1872,17 @@ } }, "node_modules/@radix-ui/react-toggle-group": { - "version": "1.1.7", - "resolved": "https://registry.npmjs.org/@radix-ui/react-toggle-group/-/react-toggle-group-1.1.7.tgz", - "integrity": "sha512-GRaPJhxrRSOqAcmcX3MwRL/SZACkoYdmoY9/sg7Bd5DhBYsB2t4co0NxTvVW8H7jUmieQDQwRtUlZ5Ta8UbgJA==", + "version": "1.1.10", + "resolved": "https://registry.npmjs.org/@radix-ui/react-toggle-group/-/react-toggle-group-1.1.10.tgz", + "integrity": "sha512-kiU694Km3WFLTC75DdqgM/3Jauf3rD9wxeS9XtyWFKsBUeZA337lC+6uUazT7I1DhanZ5gyD5Stf8uf2dbQxOQ==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", 
"@radix-ui/react-context": "1.1.2", "@radix-ui/react-direction": "1.1.1", - "@radix-ui/react-primitive": "2.1.0", - "@radix-ui/react-roving-focus": "1.1.7", - "@radix-ui/react-toggle": "1.1.6", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-roving-focus": "1.1.10", + "@radix-ui/react-toggle": "1.1.9", "@radix-ui/react-use-controllable-state": "1.2.2" }, "peerDependencies": { @@ -1794,18 +1901,18 @@ } }, "node_modules/@radix-ui/react-toolbar": { - "version": "1.1.7", - "resolved": "https://registry.npmjs.org/@radix-ui/react-toolbar/-/react-toolbar-1.1.7.tgz", - "integrity": "sha512-cL/3snRskM0f955waP+m4Pmr8+QOPpPsfoY5kM06k7eWP41diOcyjLEqSxpd/K9S7fpsV66yq4R6yN2sMwXc6Q==", + "version": "1.1.10", + "resolved": "https://registry.npmjs.org/@radix-ui/react-toolbar/-/react-toolbar-1.1.10.tgz", + "integrity": "sha512-jiwQsduEL++M4YBIurjSa+voD86OIytCod0/dbIxFZDLD8NfO1//keXYMfsW8BPcfqwoNjt+y06XcJqAb4KR7A==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-context": "1.1.2", "@radix-ui/react-direction": "1.1.1", - "@radix-ui/react-primitive": "2.1.0", - "@radix-ui/react-roving-focus": "1.1.7", - "@radix-ui/react-separator": "1.1.4", - "@radix-ui/react-toggle-group": "1.1.7" + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-roving-focus": "1.1.10", + "@radix-ui/react-separator": "1.1.7", + "@radix-ui/react-toggle-group": "1.1.10" }, "peerDependencies": { "@types/react": "*", @@ -1823,23 +1930,23 @@ } }, "node_modules/@radix-ui/react-tooltip": { - "version": "1.2.4", - "resolved": "https://registry.npmjs.org/@radix-ui/react-tooltip/-/react-tooltip-1.2.4.tgz", - "integrity": "sha512-DyW8VVeeMSSLFvAmnVnCwvI3H+1tpJFHT50r+tdOoMse9XqYDBCcyux8u3G2y+LOpt7fPQ6KKH0mhs+ce1+Z5w==", + "version": "1.2.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-tooltip/-/react-tooltip-1.2.7.tgz", + "integrity": "sha512-Ap+fNYwKTYJ9pzqW+Xe2HtMRbQ/EeWkj2qykZ6SuEV4iS/o1bZI5ssJbk4D2r8XuDuOBVz/tIx2JObtuqU+5Zw==", "license": "MIT", 
"dependencies": { "@radix-ui/primitive": "1.1.2", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", - "@radix-ui/react-dismissable-layer": "1.1.7", + "@radix-ui/react-dismissable-layer": "1.1.10", "@radix-ui/react-id": "1.1.1", - "@radix-ui/react-popper": "1.2.4", - "@radix-ui/react-portal": "1.1.6", + "@radix-ui/react-popper": "1.2.7", + "@radix-ui/react-portal": "1.1.9", "@radix-ui/react-presence": "1.1.4", - "@radix-ui/react-primitive": "2.1.0", - "@radix-ui/react-slot": "1.2.0", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-slot": "1.2.3", "@radix-ui/react-use-controllable-state": "1.2.2", - "@radix-ui/react-visually-hidden": "1.2.0" + "@radix-ui/react-visually-hidden": "1.2.3" }, "peerDependencies": { "@types/react": "*", @@ -2011,12 +2118,12 @@ } }, "node_modules/@radix-ui/react-visually-hidden": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/@radix-ui/react-visually-hidden/-/react-visually-hidden-1.2.0.tgz", - "integrity": "sha512-rQj0aAWOpCdCMRbI6pLQm8r7S2BM3YhTa0SzOYD55k+hJA8oo9J+H+9wLM9oMlZWOX/wJWPTzfDfmZkf7LvCfg==", + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-visually-hidden/-/react-visually-hidden-1.2.3.tgz", + "integrity": "sha512-pzJq12tEaaIhqjbzpCuv/OypJY/BPavOofm+dbab+MHLajy277+1lLm6JFcGgF5eskJ6mquGirhXY2GD/8u8Ug==", "license": "MIT", "dependencies": { - "@radix-ui/react-primitive": "2.1.0" + "@radix-ui/react-primitive": "2.1.3" }, "peerDependencies": { "@types/react": "*", @@ -2109,46 +2216,54 @@ } }, "node_modules/@tailwindcss/node": { - "version": "4.1.4", - "resolved": "https://registry.npmjs.org/@tailwindcss/node/-/node-4.1.4.tgz", - "integrity": "sha512-MT5118zaiO6x6hNA04OWInuAiP1YISXql8Z+/Y8iisV5nuhM8VXlyhRuqc2PEviPszcXI66W44bCIk500Oolhw==", + "version": "4.1.8", + "resolved": "https://registry.npmjs.org/@tailwindcss/node/-/node-4.1.8.tgz", + "integrity": 
"sha512-OWwBsbC9BFAJelmnNcrKuf+bka2ZxCE2A4Ft53Tkg4uoiE67r/PMEYwCsourC26E+kmxfwE0hVzMdxqeW+xu7Q==", "dev": true, "license": "MIT", "dependencies": { + "@ampproject/remapping": "^2.3.0", "enhanced-resolve": "^5.18.1", "jiti": "^2.4.2", - "lightningcss": "1.29.2", - "tailwindcss": "4.1.4" + "lightningcss": "1.30.1", + "magic-string": "^0.30.17", + "source-map-js": "^1.2.1", + "tailwindcss": "4.1.8" } }, "node_modules/@tailwindcss/oxide": { - "version": "4.1.4", - "resolved": "https://registry.npmjs.org/@tailwindcss/oxide/-/oxide-4.1.4.tgz", - "integrity": "sha512-p5wOpXyOJx7mKh5MXh5oKk+kqcz8T+bA3z/5VWWeQwFrmuBItGwz8Y2CHk/sJ+dNb9B0nYFfn0rj/cKHZyjahQ==", + "version": "4.1.8", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide/-/oxide-4.1.8.tgz", + "integrity": "sha512-d7qvv9PsM5N3VNKhwVUhpK6r4h9wtLkJ6lz9ZY9aeZgrUWk1Z8VPyqyDT9MZlem7GTGseRQHkeB1j3tC7W1P+A==", "dev": true, + "hasInstallScript": true, "license": "MIT", + "dependencies": { + "detect-libc": "^2.0.4", + "tar": "^7.4.3" + }, "engines": { "node": ">= 10" }, "optionalDependencies": { - "@tailwindcss/oxide-android-arm64": "4.1.4", - "@tailwindcss/oxide-darwin-arm64": "4.1.4", - "@tailwindcss/oxide-darwin-x64": "4.1.4", - "@tailwindcss/oxide-freebsd-x64": "4.1.4", - "@tailwindcss/oxide-linux-arm-gnueabihf": "4.1.4", - "@tailwindcss/oxide-linux-arm64-gnu": "4.1.4", - "@tailwindcss/oxide-linux-arm64-musl": "4.1.4", - "@tailwindcss/oxide-linux-x64-gnu": "4.1.4", - "@tailwindcss/oxide-linux-x64-musl": "4.1.4", - "@tailwindcss/oxide-wasm32-wasi": "4.1.4", - "@tailwindcss/oxide-win32-arm64-msvc": "4.1.4", - "@tailwindcss/oxide-win32-x64-msvc": "4.1.4" + "@tailwindcss/oxide-android-arm64": "4.1.8", + "@tailwindcss/oxide-darwin-arm64": "4.1.8", + "@tailwindcss/oxide-darwin-x64": "4.1.8", + "@tailwindcss/oxide-freebsd-x64": "4.1.8", + "@tailwindcss/oxide-linux-arm-gnueabihf": "4.1.8", + "@tailwindcss/oxide-linux-arm64-gnu": "4.1.8", + "@tailwindcss/oxide-linux-arm64-musl": "4.1.8", + 
"@tailwindcss/oxide-linux-x64-gnu": "4.1.8", + "@tailwindcss/oxide-linux-x64-musl": "4.1.8", + "@tailwindcss/oxide-wasm32-wasi": "4.1.8", + "@tailwindcss/oxide-win32-arm64-msvc": "4.1.8", + "@tailwindcss/oxide-win32-x64-msvc": "4.1.8" } }, "node_modules/@tailwindcss/oxide-android-arm64": { - "version": "4.1.4", - "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-android-arm64/-/oxide-android-arm64-4.1.4.tgz", - "integrity": "sha512-xMMAe/SaCN/vHfQYui3fqaBDEXMu22BVwQ33veLc8ep+DNy7CWN52L+TTG9y1K397w9nkzv+Mw+mZWISiqhmlA==", + "version": "4.1.8", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-android-arm64/-/oxide-android-arm64-4.1.8.tgz", + "integrity": "sha512-Fbz7qni62uKYceWYvUjRqhGfZKwhZDQhlrJKGtnZfuNtHFqa8wmr+Wn74CTWERiW2hn3mN5gTpOoxWKk0jRxjg==", "cpu": [ "arm64" ], @@ -2163,9 +2278,9 @@ } }, "node_modules/@tailwindcss/oxide-darwin-arm64": { - "version": "4.1.4", - "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-darwin-arm64/-/oxide-darwin-arm64-4.1.4.tgz", - "integrity": "sha512-JGRj0SYFuDuAGilWFBlshcexev2hOKfNkoX+0QTksKYq2zgF9VY/vVMq9m8IObYnLna0Xlg+ytCi2FN2rOL0Sg==", + "version": "4.1.8", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-darwin-arm64/-/oxide-darwin-arm64-4.1.8.tgz", + "integrity": "sha512-RdRvedGsT0vwVVDztvyXhKpsU2ark/BjgG0huo4+2BluxdXo8NDgzl77qh0T1nUxmM11eXwR8jA39ibvSTbi7A==", "cpu": [ "arm64" ], @@ -2180,9 +2295,9 @@ } }, "node_modules/@tailwindcss/oxide-darwin-x64": { - "version": "4.1.4", - "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-darwin-x64/-/oxide-darwin-x64-4.1.4.tgz", - "integrity": "sha512-sdDeLNvs3cYeWsEJ4H1DvjOzaGios4QbBTNLVLVs0XQ0V95bffT3+scptzYGPMjm7xv4+qMhCDrkHwhnUySEzA==", + "version": "4.1.8", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-darwin-x64/-/oxide-darwin-x64-4.1.8.tgz", + "integrity": "sha512-t6PgxjEMLp5Ovf7uMb2OFmb3kqzVTPPakWpBIFzppk4JE4ix0yEtbtSjPbU8+PZETpaYMtXvss2Sdkx8Vs4XRw==", "cpu": [ "x64" ], @@ -2197,9 +2312,9 @@ } }, 
"node_modules/@tailwindcss/oxide-freebsd-x64": { - "version": "4.1.4", - "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-freebsd-x64/-/oxide-freebsd-x64-4.1.4.tgz", - "integrity": "sha512-VHxAqxqdghM83HslPhRsNhHo91McsxRJaEnShJOMu8mHmEj9Ig7ToHJtDukkuLWLzLboh2XSjq/0zO6wgvykNA==", + "version": "4.1.8", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-freebsd-x64/-/oxide-freebsd-x64-4.1.8.tgz", + "integrity": "sha512-g8C8eGEyhHTqwPStSwZNSrOlyx0bhK/V/+zX0Y+n7DoRUzyS8eMbVshVOLJTDDC+Qn9IJnilYbIKzpB9n4aBsg==", "cpu": [ "x64" ], @@ -2214,9 +2329,9 @@ } }, "node_modules/@tailwindcss/oxide-linux-arm-gnueabihf": { - "version": "4.1.4", - "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-linux-arm-gnueabihf/-/oxide-linux-arm-gnueabihf-4.1.4.tgz", - "integrity": "sha512-OTU/m/eV4gQKxy9r5acuesqaymyeSCnsx1cFto/I1WhPmi5HDxX1nkzb8KYBiwkHIGg7CTfo/AcGzoXAJBxLfg==", + "version": "4.1.8", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-linux-arm-gnueabihf/-/oxide-linux-arm-gnueabihf-4.1.8.tgz", + "integrity": "sha512-Jmzr3FA4S2tHhaC6yCjac3rGf7hG9R6Gf2z9i9JFcuyy0u79HfQsh/thifbYTF2ic82KJovKKkIB6Z9TdNhCXQ==", "cpu": [ "arm" ], @@ -2231,9 +2346,9 @@ } }, "node_modules/@tailwindcss/oxide-linux-arm64-gnu": { - "version": "4.1.4", - "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-linux-arm64-gnu/-/oxide-linux-arm64-gnu-4.1.4.tgz", - "integrity": "sha512-hKlLNvbmUC6z5g/J4H+Zx7f7w15whSVImokLPmP6ff1QqTVE+TxUM9PGuNsjHvkvlHUtGTdDnOvGNSEUiXI1Ww==", + "version": "4.1.8", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-linux-arm64-gnu/-/oxide-linux-arm64-gnu-4.1.8.tgz", + "integrity": "sha512-qq7jXtO1+UEtCmCeBBIRDrPFIVI4ilEQ97qgBGdwXAARrUqSn/L9fUrkb1XP/mvVtoVeR2bt/0L77xx53bPZ/Q==", "cpu": [ "arm64" ], @@ -2248,9 +2363,9 @@ } }, "node_modules/@tailwindcss/oxide-linux-arm64-musl": { - "version": "4.1.4", - "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-linux-arm64-musl/-/oxide-linux-arm64-musl-4.1.4.tgz", - "integrity": 
"sha512-X3As2xhtgPTY/m5edUtddmZ8rCruvBvtxYLMw9OsZdH01L2gS2icsHRwxdU0dMItNfVmrBezueXZCHxVeeb7Aw==", + "version": "4.1.8", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-linux-arm64-musl/-/oxide-linux-arm64-musl-4.1.8.tgz", + "integrity": "sha512-O6b8QesPbJCRshsNApsOIpzKt3ztG35gfX9tEf4arD7mwNinsoCKxkj8TgEE0YRjmjtO3r9FlJnT/ENd9EVefQ==", "cpu": [ "arm64" ], @@ -2265,9 +2380,9 @@ } }, "node_modules/@tailwindcss/oxide-linux-x64-gnu": { - "version": "4.1.4", - "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-linux-x64-gnu/-/oxide-linux-x64-gnu-4.1.4.tgz", - "integrity": "sha512-2VG4DqhGaDSmYIu6C4ua2vSLXnJsb/C9liej7TuSO04NK+JJJgJucDUgmX6sn7Gw3Cs5ZJ9ZLrnI0QRDOjLfNQ==", + "version": "4.1.8", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-linux-x64-gnu/-/oxide-linux-x64-gnu-4.1.8.tgz", + "integrity": "sha512-32iEXX/pXwikshNOGnERAFwFSfiltmijMIAbUhnNyjFr3tmWmMJWQKU2vNcFX0DACSXJ3ZWcSkzNbaKTdngH6g==", "cpu": [ "x64" ], @@ -2282,9 +2397,9 @@ } }, "node_modules/@tailwindcss/oxide-linux-x64-musl": { - "version": "4.1.4", - "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-linux-x64-musl/-/oxide-linux-x64-musl-4.1.4.tgz", - "integrity": "sha512-v+mxVgH2kmur/X5Mdrz9m7TsoVjbdYQT0b4Z+dr+I4RvreCNXyCFELZL/DO0M1RsidZTrm6O1eMnV6zlgEzTMQ==", + "version": "4.1.8", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-linux-x64-musl/-/oxide-linux-x64-musl-4.1.8.tgz", + "integrity": "sha512-s+VSSD+TfZeMEsCaFaHTaY5YNj3Dri8rST09gMvYQKwPphacRG7wbuQ5ZJMIJXN/puxPcg/nU+ucvWguPpvBDg==", "cpu": [ "x64" ], @@ -2299,9 +2414,9 @@ } }, "node_modules/@tailwindcss/oxide-wasm32-wasi": { - "version": "4.1.4", - "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-wasm32-wasi/-/oxide-wasm32-wasi-4.1.4.tgz", - "integrity": "sha512-2TLe9ir+9esCf6Wm+lLWTMbgklIjiF0pbmDnwmhR9MksVOq+e8aP3TSsXySnBDDvTTVd/vKu1aNttEGj3P6l8Q==", + "version": "4.1.8", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-wasm32-wasi/-/oxide-wasm32-wasi-4.1.8.tgz", + 
"integrity": "sha512-CXBPVFkpDjM67sS1psWohZ6g/2/cd+cq56vPxK4JeawelxwK4YECgl9Y9TjkE2qfF+9/s1tHHJqrC4SS6cVvSg==", "bundleDependencies": [ "@napi-rs/wasm-runtime", "@emnapi/core", @@ -2317,10 +2432,10 @@ "license": "MIT", "optional": true, "dependencies": { - "@emnapi/core": "^1.4.0", - "@emnapi/runtime": "^1.4.0", - "@emnapi/wasi-threads": "^1.0.1", - "@napi-rs/wasm-runtime": "^0.2.8", + "@emnapi/core": "^1.4.3", + "@emnapi/runtime": "^1.4.3", + "@emnapi/wasi-threads": "^1.0.2", + "@napi-rs/wasm-runtime": "^0.2.10", "@tybys/wasm-util": "^0.9.0", "tslib": "^2.8.0" }, @@ -2329,9 +2444,9 @@ } }, "node_modules/@tailwindcss/oxide-win32-arm64-msvc": { - "version": "4.1.4", - "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-win32-arm64-msvc/-/oxide-win32-arm64-msvc-4.1.4.tgz", - "integrity": "sha512-VlnhfilPlO0ltxW9/BgfLI5547PYzqBMPIzRrk4W7uupgCt8z6Trw/tAj6QUtF2om+1MH281Pg+HHUJoLesmng==", + "version": "4.1.8", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-win32-arm64-msvc/-/oxide-win32-arm64-msvc-4.1.8.tgz", + "integrity": "sha512-7GmYk1n28teDHUjPlIx4Z6Z4hHEgvP5ZW2QS9ygnDAdI/myh3HTHjDqtSqgu1BpRoI4OiLx+fThAyA1JePoENA==", "cpu": [ "arm64" ], @@ -2346,9 +2461,9 @@ } }, "node_modules/@tailwindcss/oxide-win32-x64-msvc": { - "version": "4.1.4", - "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-win32-x64-msvc/-/oxide-win32-x64-msvc-4.1.4.tgz", - "integrity": "sha512-+7S63t5zhYjslUGb8NcgLpFXD+Kq1F/zt5Xv5qTv7HaFTG/DHyHD9GA6ieNAxhgyA4IcKa/zy7Xx4Oad2/wuhw==", + "version": "4.1.8", + "resolved": "https://registry.npmjs.org/@tailwindcss/oxide-win32-x64-msvc/-/oxide-win32-x64-msvc-4.1.8.tgz", + "integrity": "sha512-fou+U20j+Jl0EHwK92spoWISON2OBnCazIc038Xj2TdweYV33ZRkS9nwqiUi2d/Wba5xg5UoHfvynnb/UB49cQ==", "cpu": [ "x64" ], @@ -2363,17 +2478,17 @@ } }, "node_modules/@tailwindcss/postcss": { - "version": "4.1.4", - "resolved": "https://registry.npmjs.org/@tailwindcss/postcss/-/postcss-4.1.4.tgz", - "integrity": 
"sha512-bjV6sqycCEa+AQSt2Kr7wpGF1bOZJ5wsqnLEkqSbM/JEHxx/yhMH8wHmdkPyApF9xhHeMSwnnkDUUMMM/hYnXw==", + "version": "4.1.8", + "resolved": "https://registry.npmjs.org/@tailwindcss/postcss/-/postcss-4.1.8.tgz", + "integrity": "sha512-vB/vlf7rIky+w94aWMw34bWW1ka6g6C3xIOdICKX2GC0VcLtL6fhlLiafF0DVIwa9V6EHz8kbWMkS2s2QvvNlw==", "dev": true, "license": "MIT", "dependencies": { "@alloc/quick-lru": "^5.2.0", - "@tailwindcss/node": "4.1.4", - "@tailwindcss/oxide": "4.1.4", + "@tailwindcss/node": "4.1.8", + "@tailwindcss/oxide": "4.1.8", "postcss": "^8.4.41", - "tailwindcss": "4.1.4" + "tailwindcss": "4.1.8" } }, "node_modules/@tybys/wasm-util": { @@ -2479,9 +2594,9 @@ "license": "MIT" }, "node_modules/@types/node": { - "version": "20.17.32", - "resolved": "https://registry.npmjs.org/@types/node/-/node-20.17.32.tgz", - "integrity": "sha512-zeMXFn8zQ+UkjK4ws0RiOC9EWByyW1CcVmLe+2rQocXRsGEDxUCwPEIVgpsGcLHS/P8JkT0oa3839BRABS0oPw==", + "version": "20.17.57", + "resolved": "https://registry.npmjs.org/@types/node/-/node-20.17.57.tgz", + "integrity": "sha512-f3T4y6VU4fVQDKVqJV4Uppy8c1p/sVvS3peyqxyWnzkqXFJLRU7Y1Bl7rMS1Qe9z0v4M6McY0Fp9yBsgHJUsWQ==", "dev": true, "license": "MIT", "dependencies": { @@ -2496,9 +2611,9 @@ "license": "MIT" }, "node_modules/@types/react": { - "version": "18.3.20", - "resolved": "https://registry.npmjs.org/@types/react/-/react-18.3.20.tgz", - "integrity": "sha512-IPaCZN7PShZK/3t6Q87pfTkRm6oLTd4vztyoj+cbHUF1g3FfVb2tFIL79uCRKEfv16AhqDMBywP2VW3KIZUvcg==", + "version": "18.3.23", + "resolved": "https://registry.npmjs.org/@types/react/-/react-18.3.23.tgz", + "integrity": "sha512-/LDXMQh55EzZQ0uVAZmKKhfENivEvWz6E+EYzh+/MCjMhNsotd+ZHhBGIjFDTi6+fz0OhQQQLbTgdQIxxCsC0w==", "devOptional": true, "license": "MIT", "dependencies": { @@ -2507,9 +2622,9 @@ } }, "node_modules/@types/react-dom": { - "version": "18.3.6", - "resolved": "https://registry.npmjs.org/@types/react-dom/-/react-dom-18.3.6.tgz", - "integrity": 
"sha512-nf22//wEbKXusP6E9pfOCDwFdHAX4u172eaJI4YkDRQEZiorm6KfYnSC2SWLDMVWUOWPERmJnN0ujeAfTBLvrw==", + "version": "18.3.7", + "resolved": "https://registry.npmjs.org/@types/react-dom/-/react-dom-18.3.7.tgz", + "integrity": "sha512-MEe3UeoENYVFXzoXEWsvcpg6ZvlrFNlOQ7EOsvhI3CfAXwzPfO8Qwuxd40nepsYKqyyVQnTdEfv68q91yLcKrQ==", "devOptional": true, "license": "MIT", "peerDependencies": { @@ -2727,9 +2842,9 @@ "license": "ISC" }, "node_modules/@unrs/resolver-binding-darwin-arm64": { - "version": "1.7.2", - "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-darwin-arm64/-/resolver-binding-darwin-arm64-1.7.2.tgz", - "integrity": "sha512-vxtBno4xvowwNmO/ASL0Y45TpHqmNkAaDtz4Jqb+clmcVSSl8XCG/PNFFkGsXXXS6AMjP+ja/TtNCFFa1QwLRg==", + "version": "1.7.10", + "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-darwin-arm64/-/resolver-binding-darwin-arm64-1.7.10.tgz", + "integrity": "sha512-ABsM3eEiL3yu903G0uxgvGAoIw011XjTzyEk//gGtuVY1PuXP2IJG6novd6DBjm7MaWmRV/CZFY1rWBXSlSVVw==", "cpu": [ "arm64" ], @@ -2741,9 +2856,9 @@ ] }, "node_modules/@unrs/resolver-binding-darwin-x64": { - "version": "1.7.2", - "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-darwin-x64/-/resolver-binding-darwin-x64-1.7.2.tgz", - "integrity": "sha512-qhVa8ozu92C23Hsmv0BF4+5Dyyd5STT1FolV4whNgbY6mj3kA0qsrGPe35zNR3wAN7eFict3s4Rc2dDTPBTuFQ==", + "version": "1.7.10", + "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-darwin-x64/-/resolver-binding-darwin-x64-1.7.10.tgz", + "integrity": "sha512-lGVWy4FQEDo/PuI1VQXaQCY0XUg4xUJilf3fQ8NY4wtsQTm9lbasbUYf3nkoma+O2/do90jQTqkb02S3meyTDg==", "cpu": [ "x64" ], @@ -2755,9 +2870,9 @@ ] }, "node_modules/@unrs/resolver-binding-freebsd-x64": { - "version": "1.7.2", - "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-freebsd-x64/-/resolver-binding-freebsd-x64-1.7.2.tgz", - "integrity": "sha512-zKKdm2uMXqLFX6Ac7K5ElnnG5VIXbDlFWzg4WJ8CGUedJryM5A3cTgHuGMw1+P5ziV8CRhnSEgOnurTI4vpHpg==", + "version": "1.7.10", + "resolved": 
"https://registry.npmjs.org/@unrs/resolver-binding-freebsd-x64/-/resolver-binding-freebsd-x64-1.7.10.tgz", + "integrity": "sha512-g9XLCHzNGatY79JJNgxrUH6uAAfBDj2NWIlTnqQN5odwGKjyVfFZ5tFL1OxYPcxTHh384TY5lvTtF+fuEZNvBQ==", "cpu": [ "x64" ], @@ -2769,9 +2884,9 @@ ] }, "node_modules/@unrs/resolver-binding-linux-arm-gnueabihf": { - "version": "1.7.2", - "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-linux-arm-gnueabihf/-/resolver-binding-linux-arm-gnueabihf-1.7.2.tgz", - "integrity": "sha512-8N1z1TbPnHH+iDS/42GJ0bMPLiGK+cUqOhNbMKtWJ4oFGzqSJk/zoXFzcQkgtI63qMcUI7wW1tq2usZQSb2jxw==", + "version": "1.7.10", + "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-linux-arm-gnueabihf/-/resolver-binding-linux-arm-gnueabihf-1.7.10.tgz", + "integrity": "sha512-zV0ZMNy50sJFJapsjec8onyL9YREQKT88V8KwMoOA+zki/duFUP0oyTlbax1jGKdh8rQnruvW9VYkovGvdBAsw==", "cpu": [ "arm" ], @@ -2783,9 +2898,9 @@ ] }, "node_modules/@unrs/resolver-binding-linux-arm-musleabihf": { - "version": "1.7.2", - "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-linux-arm-musleabihf/-/resolver-binding-linux-arm-musleabihf-1.7.2.tgz", - "integrity": "sha512-tjYzI9LcAXR9MYd9rO45m1s0B/6bJNuZ6jeOxo1pq1K6OBuRMMmfyvJYval3s9FPPGmrldYA3mi4gWDlWuTFGA==", + "version": "1.7.10", + "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-linux-arm-musleabihf/-/resolver-binding-linux-arm-musleabihf-1.7.10.tgz", + "integrity": "sha512-jQxgb1DIDI7goyrabh4uvyWWBrFRfF+OOnS9SbF15h52g3Qjn/u8zG7wOQ0NjtcSMftzO75TITu9aHuI7FcqQQ==", "cpu": [ "arm" ], @@ -2797,9 +2912,9 @@ ] }, "node_modules/@unrs/resolver-binding-linux-arm64-gnu": { - "version": "1.7.2", - "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-linux-arm64-gnu/-/resolver-binding-linux-arm64-gnu-1.7.2.tgz", - "integrity": "sha512-jon9M7DKRLGZ9VYSkFMflvNqu9hDtOCEnO2QAryFWgT6o6AXU8du56V7YqnaLKr6rAbZBWYsYpikF226v423QA==", + "version": "1.7.10", + "resolved": 
"https://registry.npmjs.org/@unrs/resolver-binding-linux-arm64-gnu/-/resolver-binding-linux-arm64-gnu-1.7.10.tgz", + "integrity": "sha512-9wVVlO6+aNlm90YWitwSI++HyCyBkzYCwMi7QbuGrTxDFm2pAgtpT0OEliaI7tLS8lAWYuDbzRRCJDgsdm6nwg==", "cpu": [ "arm64" ], @@ -2811,9 +2926,9 @@ ] }, "node_modules/@unrs/resolver-binding-linux-arm64-musl": { - "version": "1.7.2", - "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-linux-arm64-musl/-/resolver-binding-linux-arm64-musl-1.7.2.tgz", - "integrity": "sha512-c8Cg4/h+kQ63pL43wBNaVMmOjXI/X62wQmru51qjfTvI7kmCy5uHTJvK/9LrF0G8Jdx8r34d019P1DVJmhXQpA==", + "version": "1.7.10", + "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-linux-arm64-musl/-/resolver-binding-linux-arm64-musl-1.7.10.tgz", + "integrity": "sha512-FtFweORChdXOes0RAAyTZp6I4PodU2cZiSILAbGaEKDXp378UOumD2vaAkWHNxpsreQUKRxG5O1uq9EoV1NiVQ==", "cpu": [ "arm64" ], @@ -2825,9 +2940,9 @@ ] }, "node_modules/@unrs/resolver-binding-linux-ppc64-gnu": { - "version": "1.7.2", - "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-linux-ppc64-gnu/-/resolver-binding-linux-ppc64-gnu-1.7.2.tgz", - "integrity": "sha512-A+lcwRFyrjeJmv3JJvhz5NbcCkLQL6Mk16kHTNm6/aGNc4FwPHPE4DR9DwuCvCnVHvF5IAd9U4VIs/VvVir5lg==", + "version": "1.7.10", + "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-linux-ppc64-gnu/-/resolver-binding-linux-ppc64-gnu-1.7.10.tgz", + "integrity": "sha512-B+hOjpG2ncCR96a9d9ww1dWVuRVC2NChD0bITgrUhEWBhpdv2o/Mu2l8MsB2fzjdV/ku+twaQhr8iLHBoZafZQ==", "cpu": [ "ppc64" ], @@ -2839,9 +2954,9 @@ ] }, "node_modules/@unrs/resolver-binding-linux-riscv64-gnu": { - "version": "1.7.2", - "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-linux-riscv64-gnu/-/resolver-binding-linux-riscv64-gnu-1.7.2.tgz", - "integrity": "sha512-hQQ4TJQrSQW8JlPm7tRpXN8OCNP9ez7PajJNjRD1ZTHQAy685OYqPrKjfaMw/8LiHCt8AZ74rfUVHP9vn0N69Q==", + "version": "1.7.10", + "resolved": 
"https://registry.npmjs.org/@unrs/resolver-binding-linux-riscv64-gnu/-/resolver-binding-linux-riscv64-gnu-1.7.10.tgz", + "integrity": "sha512-DS6jFDoQCFsnsdLXlj3z3THakQLBic63B6A0rpQ1kpkyKa3OzEfqhwRNVaywuUuOKP9bX55Jk2uqpvn/hGjKCg==", "cpu": [ "riscv64" ], @@ -2853,9 +2968,9 @@ ] }, "node_modules/@unrs/resolver-binding-linux-riscv64-musl": { - "version": "1.7.2", - "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-linux-riscv64-musl/-/resolver-binding-linux-riscv64-musl-1.7.2.tgz", - "integrity": "sha512-NoAGbiqrxtY8kVooZ24i70CjLDlUFI7nDj3I9y54U94p+3kPxwd2L692YsdLa+cqQ0VoqMWoehDFp21PKRUoIQ==", + "version": "1.7.10", + "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-linux-riscv64-musl/-/resolver-binding-linux-riscv64-musl-1.7.10.tgz", + "integrity": "sha512-A82SB6yEaA8EhIW2r0I7P+k5lg7zPscFnGs1Gna5rfPwoZjeUAGX76T55+DiyTiy08VFKUi79PGCulXnfjDq0g==", "cpu": [ "riscv64" ], @@ -2867,9 +2982,9 @@ ] }, "node_modules/@unrs/resolver-binding-linux-s390x-gnu": { - "version": "1.7.2", - "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-linux-s390x-gnu/-/resolver-binding-linux-s390x-gnu-1.7.2.tgz", - "integrity": "sha512-KaZByo8xuQZbUhhreBTW+yUnOIHUsv04P8lKjQ5otiGoSJ17ISGYArc+4vKdLEpGaLbemGzr4ZeUbYQQsLWFjA==", + "version": "1.7.10", + "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-linux-s390x-gnu/-/resolver-binding-linux-s390x-gnu-1.7.10.tgz", + "integrity": "sha512-J+VmOPH16U69QshCp9WS+Zuiuu9GWTISKchKIhLbS/6JSCEfw2A4N02whv2VmrkXE287xxZbhW1p6xlAXNzwqg==", "cpu": [ "s390x" ], @@ -2881,9 +2996,9 @@ ] }, "node_modules/@unrs/resolver-binding-linux-x64-gnu": { - "version": "1.7.2", - "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-linux-x64-gnu/-/resolver-binding-linux-x64-gnu-1.7.2.tgz", - "integrity": "sha512-dEidzJDubxxhUCBJ/SHSMJD/9q7JkyfBMT77Px1npl4xpg9t0POLvnWywSk66BgZS/b2Hy9Y1yFaoMTFJUe9yg==", + "version": "1.7.10", + "resolved": 
"https://registry.npmjs.org/@unrs/resolver-binding-linux-x64-gnu/-/resolver-binding-linux-x64-gnu-1.7.10.tgz", + "integrity": "sha512-bYTdDltcB/V3fEqpx8YDwDw8ta9uEg8TUbJOtek6JM42u9ciJ7R/jBjNeAOs+QbyxGDd2d6xkBaGwty1HzOz3Q==", "cpu": [ "x64" ], @@ -2895,9 +3010,9 @@ ] }, "node_modules/@unrs/resolver-binding-linux-x64-musl": { - "version": "1.7.2", - "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-linux-x64-musl/-/resolver-binding-linux-x64-musl-1.7.2.tgz", - "integrity": "sha512-RvP+Ux3wDjmnZDT4XWFfNBRVG0fMsc+yVzNFUqOflnDfZ9OYujv6nkh+GOr+watwrW4wdp6ASfG/e7bkDradsw==", + "version": "1.7.10", + "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-linux-x64-musl/-/resolver-binding-linux-x64-musl-1.7.10.tgz", + "integrity": "sha512-NYZ1GvSuTokJ28lqcjrMTnGMySoo4dVcNK/nsNCKCXT++1zekZtJaE+N+4jc1kR7EV0fc1OhRrOGcSt7FT9t8w==", "cpu": [ "x64" ], @@ -2909,9 +3024,9 @@ ] }, "node_modules/@unrs/resolver-binding-wasm32-wasi": { - "version": "1.7.2", - "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-wasm32-wasi/-/resolver-binding-wasm32-wasi-1.7.2.tgz", - "integrity": "sha512-y797JBmO9IsvXVRCKDXOxjyAE4+CcZpla2GSoBQ33TVb3ILXuFnMrbR/QQZoauBYeOFuu4w3ifWLw52sdHGz6g==", + "version": "1.7.10", + "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-wasm32-wasi/-/resolver-binding-wasm32-wasi-1.7.10.tgz", + "integrity": "sha512-MRjJhTaQzLoX8OtzRBQDJ84OJ8IX1FqpRAUSxp/JtPeak+fyDfhXaEjcA/fhfgrACUnvC+jWC52f/V6MixSKCQ==", "cpu": [ "wasm32" ], @@ -2919,16 +3034,16 @@ "license": "MIT", "optional": true, "dependencies": { - "@napi-rs/wasm-runtime": "^0.2.9" + "@napi-rs/wasm-runtime": "^0.2.10" }, "engines": { "node": ">=14.0.0" } }, "node_modules/@unrs/resolver-binding-win32-arm64-msvc": { - "version": "1.7.2", - "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-win32-arm64-msvc/-/resolver-binding-win32-arm64-msvc-1.7.2.tgz", - "integrity": "sha512-gtYTh4/VREVSLA+gHrfbWxaMO/00y+34htY7XpioBTy56YN2eBjkPrY1ML1Zys89X3RJDKVaogzwxlM1qU7egg==", 
+ "version": "1.7.10", + "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-win32-arm64-msvc/-/resolver-binding-win32-arm64-msvc-1.7.10.tgz", + "integrity": "sha512-Cgw6qhdsfzXJnHb006CzqgaX8mD445x5FGKuueaLeH1ptCxDbzRs8wDm6VieOI7rdbstfYBaFtaYN7zBT5CUPg==", "cpu": [ "arm64" ], @@ -2940,9 +3055,9 @@ ] }, "node_modules/@unrs/resolver-binding-win32-ia32-msvc": { - "version": "1.7.2", - "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-win32-ia32-msvc/-/resolver-binding-win32-ia32-msvc-1.7.2.tgz", - "integrity": "sha512-Ywv20XHvHTDRQs12jd3MY8X5C8KLjDbg/jyaal/QLKx3fAShhJyD4blEANInsjxW3P7isHx1Blt56iUDDJO3jg==", + "version": "1.7.10", + "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-win32-ia32-msvc/-/resolver-binding-win32-ia32-msvc-1.7.10.tgz", + "integrity": "sha512-Z7oECyIT2/HsrWpJ6wi2b+lVbPmWqQHuW5zeatafoRXizk1+2wUl+aSop1PF58XcyBuwPP2YpEUUpMZ8ILV4fA==", "cpu": [ "ia32" ], @@ -2954,9 +3069,9 @@ ] }, "node_modules/@unrs/resolver-binding-win32-x64-msvc": { - "version": "1.7.2", - "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-win32-x64-msvc/-/resolver-binding-win32-x64-msvc-1.7.2.tgz", - "integrity": "sha512-friS8NEQfHaDbkThxopGk+LuE5v3iY0StruifjQEt7SLbA46OnfgMO15sOTkbpJkol6RB+1l1TYPXh0sCddpvA==", + "version": "1.7.10", + "resolved": "https://registry.npmjs.org/@unrs/resolver-binding-win32-x64-msvc/-/resolver-binding-win32-x64-msvc-1.7.10.tgz", + "integrity": "sha512-DGAOo5asNvDsmFgwkb7xsgxNyN0If6XFYwDIC1QlRE7kEYWIMRChtWJyHDf30XmGovDNOs/37krxhnga/nm/4w==", "cpu": [ "x64" ], @@ -3070,9 +3185,9 @@ "license": "Python-2.0" }, "node_modules/aria-hidden": { - "version": "1.2.4", - "resolved": "https://registry.npmjs.org/aria-hidden/-/aria-hidden-1.2.4.tgz", - "integrity": "sha512-y+CcFFwelSXpLZk/7fMB2mUbGtX9lKycf1MWJ7CaTIERyitVlyQx6C+sxcROU2BAJ24OiZyK+8wj2i8AlBoS3A==", + "version": "1.2.6", + "resolved": "https://registry.npmjs.org/aria-hidden/-/aria-hidden-1.2.6.tgz", + "integrity": 
"sha512-ik3ZgC9dY/lYVVM++OISsaYDeg1tb0VtP5uL3ouh1koGOaUMDPpbFIei4JkFimWUFPn90sbMNMXQAIVOlnYKJA==", "license": "MIT", "dependencies": { "tslib": "^2.0.0" @@ -3109,18 +3224,20 @@ } }, "node_modules/array-includes": { - "version": "3.1.8", - "resolved": "https://registry.npmjs.org/array-includes/-/array-includes-3.1.8.tgz", - "integrity": "sha512-itaWrbYbqpGXkGhZPGUulwnhVf5Hpy1xiCFsGqyIGglbBxmG5vSjxQen3/WGOjPpNEv1RtBLKxbmVXm8HpJStQ==", + "version": "3.1.9", + "resolved": "https://registry.npmjs.org/array-includes/-/array-includes-3.1.9.tgz", + "integrity": "sha512-FmeCCAenzH0KH381SPT5FZmiA/TmpndpcaShhfgEN9eCVjnFBqq3l1xrI42y8+PPLI6hypzou4GXw00WHmPBLQ==", "dev": true, "license": "MIT", "dependencies": { - "call-bind": "^1.0.7", + "call-bind": "^1.0.8", + "call-bound": "^1.0.4", "define-properties": "^1.2.1", - "es-abstract": "^1.23.2", - "es-object-atoms": "^1.0.0", - "get-intrinsic": "^1.2.4", - "is-string": "^1.0.7" + "es-abstract": "^1.24.0", + "es-object-atoms": "^1.1.1", + "get-intrinsic": "^1.3.0", + "is-string": "^1.1.1", + "math-intrinsics": "^1.1.0" }, "engines": { "node": ">= 0.4" @@ -3431,9 +3548,9 @@ } }, "node_modules/caniuse-lite": { - "version": "1.0.30001715", - "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001715.tgz", - "integrity": "sha512-7ptkFGMm2OAOgvZpwgA4yjQ5SQbrNVGdRjzH0pBdy1Fasvcr+KAeECmbCAECzTuDuoX0FCY8KzUxjf9+9kfZEw==", + "version": "1.0.30001721", + "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001721.tgz", + "integrity": "sha512-cOuvmUVtKrtEaoKiO0rSc29jcjwMwX5tOHDy4MgVFEWiUXj4uBMJkwI8MDySkgXidpMiHUcviogAvFi4pA2hDQ==", "funding": [ { "type": "opencollective", @@ -3467,6 +3584,16 @@ "url": "https://github.com/chalk/chalk?sponsor=1" } }, + "node_modules/chownr": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/chownr/-/chownr-3.0.0.tgz", + "integrity": "sha512-+IxzY9BZOQd/XuYPRmrvEVjF/nqj5kgT4kEq7VofrDoM1MxoRjEWkrCC3EtLi59TVawxTAn+orJwFQcrqEN1+g==", + "dev": true, + 
"license": "BlueOak-1.0.0", + "engines": { + "node": ">=18" + } + }, "node_modules/class-variance-authority": { "version": "0.7.1", "resolved": "https://registry.npmjs.org/class-variance-authority/-/class-variance-authority-0.7.1.tgz", @@ -3819,9 +3946,9 @@ } }, "node_modules/debug": { - "version": "4.4.0", - "resolved": "https://registry.npmjs.org/debug/-/debug-4.4.0.tgz", - "integrity": "sha512-6WTZ/IxCY/T6BALoZHaE4ctp9xm+Z5kY/pzYaCHRFeyVhojxlrm+46y68HA6hr0TcwEssoxNiDEUJQjfPZ/RYA==", + "version": "4.4.1", + "resolved": "https://registry.npmjs.org/debug/-/debug-4.4.1.tgz", + "integrity": "sha512-KcKCqiftBJcZr++7ykoDIEwSa3XWowTfNPo92BYxjXiyYEVrUQh2aLyhxBCwww+heortUFxEJYcRzosstTEBYQ==", "dev": true, "license": "MIT", "dependencies": { @@ -3995,9 +4122,9 @@ } }, "node_modules/es-abstract": { - "version": "1.23.9", - "resolved": "https://registry.npmjs.org/es-abstract/-/es-abstract-1.23.9.tgz", - "integrity": "sha512-py07lI0wjxAC/DcfK1S6G7iANonniZwTISvdPzk9hzeH0IZIshbuuFxLIU96OyF89Yb9hiqWn8M/bY83KY5vzA==", + "version": "1.24.0", + "resolved": "https://registry.npmjs.org/es-abstract/-/es-abstract-1.24.0.tgz", + "integrity": "sha512-WSzPgsdLtTcQwm4CROfS5ju2Wa1QQcVeT37jFjYzdFz1r9ahadC8B8/a4qxJxM+09F18iumCdRmlr96ZYkQvEg==", "dev": true, "license": "MIT", "dependencies": { @@ -4005,18 +4132,18 @@ "arraybuffer.prototype.slice": "^1.0.4", "available-typed-arrays": "^1.0.7", "call-bind": "^1.0.8", - "call-bound": "^1.0.3", + "call-bound": "^1.0.4", "data-view-buffer": "^1.0.2", "data-view-byte-length": "^1.0.2", "data-view-byte-offset": "^1.0.1", "es-define-property": "^1.0.1", "es-errors": "^1.3.0", - "es-object-atoms": "^1.0.0", + "es-object-atoms": "^1.1.1", "es-set-tostringtag": "^2.1.0", "es-to-primitive": "^1.3.0", "function.prototype.name": "^1.1.8", - "get-intrinsic": "^1.2.7", - "get-proto": "^1.0.0", + "get-intrinsic": "^1.3.0", + "get-proto": "^1.0.1", "get-symbol-description": "^1.1.0", "globalthis": "^1.0.4", "gopd": "^1.2.0", @@ -4028,21 +4155,24 @@ 
"is-array-buffer": "^3.0.5", "is-callable": "^1.2.7", "is-data-view": "^1.0.2", + "is-negative-zero": "^2.0.3", "is-regex": "^1.2.1", + "is-set": "^2.0.3", "is-shared-array-buffer": "^1.0.4", "is-string": "^1.1.1", "is-typed-array": "^1.1.15", - "is-weakref": "^1.1.0", + "is-weakref": "^1.1.1", "math-intrinsics": "^1.1.0", - "object-inspect": "^1.13.3", + "object-inspect": "^1.13.4", "object-keys": "^1.1.1", "object.assign": "^4.1.7", "own-keys": "^1.0.1", - "regexp.prototype.flags": "^1.5.3", + "regexp.prototype.flags": "^1.5.4", "safe-array-concat": "^1.1.3", "safe-push-apply": "^1.0.0", "safe-regex-test": "^1.1.0", "set-proto": "^1.0.0", + "stop-iteration-iterator": "^1.1.0", "string.prototype.trim": "^1.2.10", "string.prototype.trimend": "^1.0.9", "string.prototype.trimstart": "^1.0.8", @@ -4051,7 +4181,7 @@ "typed-array-byte-offset": "^1.0.4", "typed-array-length": "^1.0.7", "unbox-primitive": "^1.1.0", - "which-typed-array": "^1.1.18" + "which-typed-array": "^1.1.19" }, "engines": { "node": ">= 0.4" @@ -4268,6 +4398,41 @@ } } }, + "node_modules/eslint-config-next/node_modules/eslint-import-resolver-typescript": { + "version": "3.10.1", + "resolved": "https://registry.npmjs.org/eslint-import-resolver-typescript/-/eslint-import-resolver-typescript-3.10.1.tgz", + "integrity": "sha512-A1rHYb06zjMGAxdLSkN2fXPBwuSaQ0iO5M/hdyS0Ajj1VBaRp0sPD3dn1FhME3c/JluGFbwSxyCfqdSbtQLAHQ==", + "dev": true, + "license": "ISC", + "dependencies": { + "@nolyfill/is-core-module": "1.0.39", + "debug": "^4.4.0", + "get-tsconfig": "^4.10.0", + "is-bun-module": "^2.0.0", + "stable-hash": "^0.0.5", + "tinyglobby": "^0.2.13", + "unrs-resolver": "^1.6.2" + }, + "engines": { + "node": "^14.18.0 || >=16.0.0" + }, + "funding": { + "url": "https://opencollective.com/eslint-import-resolver-typescript" + }, + "peerDependencies": { + "eslint": "*", + "eslint-plugin-import": "*", + "eslint-plugin-import-x": "*" + }, + "peerDependenciesMeta": { + "eslint-plugin-import": { + "optional": true + }, + 
"eslint-plugin-import-x": { + "optional": true + } + } + }, "node_modules/eslint-config-prettier": { "version": "8.10.0", "resolved": "https://registry.npmjs.org/eslint-config-prettier/-/eslint-config-prettier-8.10.0.tgz", @@ -4303,41 +4468,6 @@ "ms": "^2.1.1" } }, - "node_modules/eslint-import-resolver-typescript": { - "version": "3.10.1", - "resolved": "https://registry.npmjs.org/eslint-import-resolver-typescript/-/eslint-import-resolver-typescript-3.10.1.tgz", - "integrity": "sha512-A1rHYb06zjMGAxdLSkN2fXPBwuSaQ0iO5M/hdyS0Ajj1VBaRp0sPD3dn1FhME3c/JluGFbwSxyCfqdSbtQLAHQ==", - "dev": true, - "license": "ISC", - "dependencies": { - "@nolyfill/is-core-module": "1.0.39", - "debug": "^4.4.0", - "get-tsconfig": "^4.10.0", - "is-bun-module": "^2.0.0", - "stable-hash": "^0.0.5", - "tinyglobby": "^0.2.13", - "unrs-resolver": "^1.6.2" - }, - "engines": { - "node": "^14.18.0 || >=16.0.0" - }, - "funding": { - "url": "https://opencollective.com/eslint-import-resolver-typescript" - }, - "peerDependencies": { - "eslint": "*", - "eslint-plugin-import": "*", - "eslint-plugin-import-x": "*" - }, - "peerDependenciesMeta": { - "eslint-plugin-import": { - "optional": true - }, - "eslint-plugin-import-x": { - "optional": true - } - } - }, "node_modules/eslint-module-utils": { "version": "2.12.0", "resolved": "https://registry.npmjs.org/eslint-module-utils/-/eslint-module-utils-2.12.0.tgz", @@ -4924,14 +5054,15 @@ } }, "node_modules/form-data": { - "version": "4.0.2", - "resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.2.tgz", - "integrity": "sha512-hGfm/slu0ZabnNt4oaRZ6uREyfCj6P4fT/n6A1rGV+Z0VdGXjfOhVUpkn6qVQONHGIFwmveGXyDs75+nr6FM8w==", + "version": "4.0.3", + "resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.3.tgz", + "integrity": "sha512-qsITQPfmvMOSAdeyZ+12I1c+CKSstAFAwu+97zrnWAbIr5u8wfsExUzCesVLC8NgHuRUqNN4Zy6UPWUTRGslcA==", "license": "MIT", "dependencies": { "asynckit": "^0.4.0", "combined-stream": "^1.0.8", "es-set-tostringtag": "^2.1.0", + 
"hasown": "^2.0.2", "mime-types": "^2.1.12" }, "engines": { @@ -5103,9 +5234,9 @@ } }, "node_modules/get-tsconfig": { - "version": "4.10.0", - "resolved": "https://registry.npmjs.org/get-tsconfig/-/get-tsconfig-4.10.0.tgz", - "integrity": "sha512-kGzZ3LWWQcGIAmg6iWvXn0ei6WDtV26wzHRMwDSzmAbcXrTEXxHy6IehI6/4eT6VRKyMP1eF1VqwrVUmE/LR7A==", + "version": "4.10.1", + "resolved": "https://registry.npmjs.org/get-tsconfig/-/get-tsconfig-4.10.1.tgz", + "integrity": "sha512-auHyJ4AgMz7vgS8Hp3N6HXSmlMdUyhSUrfBF16w153rxtLIEOE+HGqaBppczZvnHLqQJfiHotCYpNhl0lUROFQ==", "dev": true, "license": "MIT", "dependencies": { @@ -5682,6 +5813,19 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/is-negative-zero": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/is-negative-zero/-/is-negative-zero-2.0.3.tgz", + "integrity": "sha512-5KoIu2Ngpyek75jXodFvnafB6DJgr3u8uuK0LEZJjrU19DrMD3EVERaR8sjz8CCGgpZvxPl9SuE1GMVPFHx1mw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, "node_modules/is-number": { "version": "7.0.0", "resolved": "https://registry.npmjs.org/is-number/-/is-number-7.0.0.tgz", @@ -6058,9 +6202,9 @@ } }, "node_modules/lightningcss": { - "version": "1.29.2", - "resolved": "https://registry.npmjs.org/lightningcss/-/lightningcss-1.29.2.tgz", - "integrity": "sha512-6b6gd/RUXKaw5keVdSEtqFVdzWnU5jMxTUjA2bVcMNPLwSQ08Sv/UodBVtETLCn7k4S1Ibxwh7k68IwLZPgKaA==", + "version": "1.30.1", + "resolved": "https://registry.npmjs.org/lightningcss/-/lightningcss-1.30.1.tgz", + "integrity": "sha512-xi6IyHML+c9+Q3W0S4fCQJOym42pyurFiJUHEcEyHS0CeKzia4yZDEsLlqOFykxOdHpNy0NmvVO31vcSqAxJCg==", "dev": true, "license": "MPL-2.0", "dependencies": { @@ -6074,22 +6218,22 @@ "url": "https://opencollective.com/parcel" }, "optionalDependencies": { - "lightningcss-darwin-arm64": "1.29.2", - "lightningcss-darwin-x64": "1.29.2", - "lightningcss-freebsd-x64": "1.29.2", - 
"lightningcss-linux-arm-gnueabihf": "1.29.2", - "lightningcss-linux-arm64-gnu": "1.29.2", - "lightningcss-linux-arm64-musl": "1.29.2", - "lightningcss-linux-x64-gnu": "1.29.2", - "lightningcss-linux-x64-musl": "1.29.2", - "lightningcss-win32-arm64-msvc": "1.29.2", - "lightningcss-win32-x64-msvc": "1.29.2" + "lightningcss-darwin-arm64": "1.30.1", + "lightningcss-darwin-x64": "1.30.1", + "lightningcss-freebsd-x64": "1.30.1", + "lightningcss-linux-arm-gnueabihf": "1.30.1", + "lightningcss-linux-arm64-gnu": "1.30.1", + "lightningcss-linux-arm64-musl": "1.30.1", + "lightningcss-linux-x64-gnu": "1.30.1", + "lightningcss-linux-x64-musl": "1.30.1", + "lightningcss-win32-arm64-msvc": "1.30.1", + "lightningcss-win32-x64-msvc": "1.30.1" } }, "node_modules/lightningcss-darwin-arm64": { - "version": "1.29.2", - "resolved": "https://registry.npmjs.org/lightningcss-darwin-arm64/-/lightningcss-darwin-arm64-1.29.2.tgz", - "integrity": "sha512-cK/eMabSViKn/PG8U/a7aCorpeKLMlK0bQeNHmdb7qUnBkNPnL+oV5DjJUo0kqWsJUapZsM4jCfYItbqBDvlcA==", + "version": "1.30.1", + "resolved": "https://registry.npmjs.org/lightningcss-darwin-arm64/-/lightningcss-darwin-arm64-1.30.1.tgz", + "integrity": "sha512-c8JK7hyE65X1MHMN+Viq9n11RRC7hgin3HhYKhrMyaXflk5GVplZ60IxyoVtzILeKr+xAJwg6zK6sjTBJ0FKYQ==", "cpu": [ "arm64" ], @@ -6108,9 +6252,9 @@ } }, "node_modules/lightningcss-darwin-x64": { - "version": "1.29.2", - "resolved": "https://registry.npmjs.org/lightningcss-darwin-x64/-/lightningcss-darwin-x64-1.29.2.tgz", - "integrity": "sha512-j5qYxamyQw4kDXX5hnnCKMf3mLlHvG44f24Qyi2965/Ycz829MYqjrVg2H8BidybHBp9kom4D7DR5VqCKDXS0w==", + "version": "1.30.1", + "resolved": "https://registry.npmjs.org/lightningcss-darwin-x64/-/lightningcss-darwin-x64-1.30.1.tgz", + "integrity": "sha512-k1EvjakfumAQoTfcXUcHQZhSpLlkAuEkdMBsI/ivWw9hL+7FtilQc0Cy3hrx0AAQrVtQAbMI7YjCgYgvn37PzA==", "cpu": [ "x64" ], @@ -6129,9 +6273,9 @@ } }, "node_modules/lightningcss-freebsd-x64": { - "version": "1.29.2", - "resolved": 
"https://registry.npmjs.org/lightningcss-freebsd-x64/-/lightningcss-freebsd-x64-1.29.2.tgz", - "integrity": "sha512-wDk7M2tM78Ii8ek9YjnY8MjV5f5JN2qNVO+/0BAGZRvXKtQrBC4/cn4ssQIpKIPP44YXw6gFdpUF+Ps+RGsCwg==", + "version": "1.30.1", + "resolved": "https://registry.npmjs.org/lightningcss-freebsd-x64/-/lightningcss-freebsd-x64-1.30.1.tgz", + "integrity": "sha512-kmW6UGCGg2PcyUE59K5r0kWfKPAVy4SltVeut+umLCFoJ53RdCUWxcRDzO1eTaxf/7Q2H7LTquFHPL5R+Gjyig==", "cpu": [ "x64" ], @@ -6150,9 +6294,9 @@ } }, "node_modules/lightningcss-linux-arm-gnueabihf": { - "version": "1.29.2", - "resolved": "https://registry.npmjs.org/lightningcss-linux-arm-gnueabihf/-/lightningcss-linux-arm-gnueabihf-1.29.2.tgz", - "integrity": "sha512-IRUrOrAF2Z+KExdExe3Rz7NSTuuJ2HvCGlMKoquK5pjvo2JY4Rybr+NrKnq0U0hZnx5AnGsuFHjGnNT14w26sg==", + "version": "1.30.1", + "resolved": "https://registry.npmjs.org/lightningcss-linux-arm-gnueabihf/-/lightningcss-linux-arm-gnueabihf-1.30.1.tgz", + "integrity": "sha512-MjxUShl1v8pit+6D/zSPq9S9dQ2NPFSQwGvxBCYaBYLPlCWuPh9/t1MRS8iUaR8i+a6w7aps+B4N0S1TYP/R+Q==", "cpu": [ "arm" ], @@ -6171,9 +6315,9 @@ } }, "node_modules/lightningcss-linux-arm64-gnu": { - "version": "1.29.2", - "resolved": "https://registry.npmjs.org/lightningcss-linux-arm64-gnu/-/lightningcss-linux-arm64-gnu-1.29.2.tgz", - "integrity": "sha512-KKCpOlmhdjvUTX/mBuaKemp0oeDIBBLFiU5Fnqxh1/DZ4JPZi4evEH7TKoSBFOSOV3J7iEmmBaw/8dpiUvRKlQ==", + "version": "1.30.1", + "resolved": "https://registry.npmjs.org/lightningcss-linux-arm64-gnu/-/lightningcss-linux-arm64-gnu-1.30.1.tgz", + "integrity": "sha512-gB72maP8rmrKsnKYy8XUuXi/4OctJiuQjcuqWNlJQ6jZiWqtPvqFziskH3hnajfvKB27ynbVCucKSm2rkQp4Bw==", "cpu": [ "arm64" ], @@ -6192,9 +6336,9 @@ } }, "node_modules/lightningcss-linux-arm64-musl": { - "version": "1.29.2", - "resolved": "https://registry.npmjs.org/lightningcss-linux-arm64-musl/-/lightningcss-linux-arm64-musl-1.29.2.tgz", - "integrity": 
"sha512-Q64eM1bPlOOUgxFmoPUefqzY1yV3ctFPE6d/Vt7WzLW4rKTv7MyYNky+FWxRpLkNASTnKQUaiMJ87zNODIrrKQ==", + "version": "1.30.1", + "resolved": "https://registry.npmjs.org/lightningcss-linux-arm64-musl/-/lightningcss-linux-arm64-musl-1.30.1.tgz", + "integrity": "sha512-jmUQVx4331m6LIX+0wUhBbmMX7TCfjF5FoOH6SD1CttzuYlGNVpA7QnrmLxrsub43ClTINfGSYyHe2HWeLl5CQ==", "cpu": [ "arm64" ], @@ -6213,9 +6357,9 @@ } }, "node_modules/lightningcss-linux-x64-gnu": { - "version": "1.29.2", - "resolved": "https://registry.npmjs.org/lightningcss-linux-x64-gnu/-/lightningcss-linux-x64-gnu-1.29.2.tgz", - "integrity": "sha512-0v6idDCPG6epLXtBH/RPkHvYx74CVziHo6TMYga8O2EiQApnUPZsbR9nFNrg2cgBzk1AYqEd95TlrsL7nYABQg==", + "version": "1.30.1", + "resolved": "https://registry.npmjs.org/lightningcss-linux-x64-gnu/-/lightningcss-linux-x64-gnu-1.30.1.tgz", + "integrity": "sha512-piWx3z4wN8J8z3+O5kO74+yr6ze/dKmPnI7vLqfSqI8bccaTGY5xiSGVIJBDd5K5BHlvVLpUB3S2YCfelyJ1bw==", "cpu": [ "x64" ], @@ -6234,9 +6378,9 @@ } }, "node_modules/lightningcss-linux-x64-musl": { - "version": "1.29.2", - "resolved": "https://registry.npmjs.org/lightningcss-linux-x64-musl/-/lightningcss-linux-x64-musl-1.29.2.tgz", - "integrity": "sha512-rMpz2yawkgGT8RULc5S4WiZopVMOFWjiItBT7aSfDX4NQav6M44rhn5hjtkKzB+wMTRlLLqxkeYEtQ3dd9696w==", + "version": "1.30.1", + "resolved": "https://registry.npmjs.org/lightningcss-linux-x64-musl/-/lightningcss-linux-x64-musl-1.30.1.tgz", + "integrity": "sha512-rRomAK7eIkL+tHY0YPxbc5Dra2gXlI63HL+v1Pdi1a3sC+tJTcFrHX+E86sulgAXeI7rSzDYhPSeHHjqFhqfeQ==", "cpu": [ "x64" ], @@ -6255,9 +6399,9 @@ } }, "node_modules/lightningcss-win32-arm64-msvc": { - "version": "1.29.2", - "resolved": "https://registry.npmjs.org/lightningcss-win32-arm64-msvc/-/lightningcss-win32-arm64-msvc-1.29.2.tgz", - "integrity": "sha512-nL7zRW6evGQqYVu/bKGK+zShyz8OVzsCotFgc7judbt6wnB2KbiKKJwBE4SGoDBQ1O94RjW4asrCjQL4i8Fhbw==", + "version": "1.30.1", + "resolved": 
"https://registry.npmjs.org/lightningcss-win32-arm64-msvc/-/lightningcss-win32-arm64-msvc-1.30.1.tgz", + "integrity": "sha512-mSL4rqPi4iXq5YVqzSsJgMVFENoa4nGTT/GjO2c0Yl9OuQfPsIfncvLrEW6RbbB24WtZ3xP/2CCmI3tNkNV4oA==", "cpu": [ "arm64" ], @@ -6276,9 +6420,9 @@ } }, "node_modules/lightningcss-win32-x64-msvc": { - "version": "1.29.2", - "resolved": "https://registry.npmjs.org/lightningcss-win32-x64-msvc/-/lightningcss-win32-x64-msvc-1.29.2.tgz", - "integrity": "sha512-EdIUW3B2vLuHmv7urfzMI/h2fmlnOQBk1xlsDxkN1tCWKjNFjfLhGxYk8C8mzpSfr+A6jFFIi8fU6LbQGsRWjA==", + "version": "1.30.1", + "resolved": "https://registry.npmjs.org/lightningcss-win32-x64-msvc/-/lightningcss-win32-x64-msvc-1.30.1.tgz", + "integrity": "sha512-PVqXh48wh4T53F/1CCu8PIPCxLzWyCnn/9T5W1Jpmdy5h9Cwd+0YQS6/LwhHXSafuc61/xg9Lv5OrCby6a++jg==", "cpu": [ "x64" ], @@ -6518,6 +6662,16 @@ "react": "^16.5.1 || ^17.0.0 || ^18.0.0 || ^19.0.0-rc" } }, + "node_modules/magic-string": { + "version": "0.30.17", + "resolved": "https://registry.npmjs.org/magic-string/-/magic-string-0.30.17.tgz", + "integrity": "sha512-sNPKHvyjVf7gyjwS4xGTaW/mCnF8wnjtifKBEhxfZ7E/S8tQ0rssrwGNn6q8JH/ohItJfSQp9mBtQYuTlH5QnA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@jridgewell/sourcemap-codec": "^1.5.0" + } + }, "node_modules/math-intrinsics": { "version": "1.1.0", "resolved": "https://registry.npmjs.org/math-intrinsics/-/math-intrinsics-1.1.0.tgz", @@ -6658,6 +6812,35 @@ "node": ">=16 || 14 >=14.17" } }, + "node_modules/minizlib": { + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/minizlib/-/minizlib-3.0.2.tgz", + "integrity": "sha512-oG62iEk+CYt5Xj2YqI5Xi9xWUeZhDI8jjQmC5oThVH5JGCTgIjr7ciJDzC7MBzYd//WvR1OTmP5Q38Q8ShQtVA==", + "dev": true, + "license": "MIT", + "dependencies": { + "minipass": "^7.1.2" + }, + "engines": { + "node": ">= 18" + } + }, + "node_modules/mkdirp": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/mkdirp/-/mkdirp-3.0.1.tgz", + "integrity": 
"sha512-+NsyUUAZDmo6YVHzL/stxSu3t9YS1iljliy3BSDrXJ/dkn1KYdmtZODGGjLcc9XLgVVpH4KshHB8XmZgMhaBXg==", + "dev": true, + "license": "MIT", + "bin": { + "mkdirp": "dist/cjs/src/bin.js" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, "node_modules/motion-dom": { "version": "11.18.1", "resolved": "https://registry.npmjs.org/motion-dom/-/motion-dom-11.18.1.tgz", @@ -6699,9 +6882,9 @@ } }, "node_modules/napi-postinstall": { - "version": "0.2.2", - "resolved": "https://registry.npmjs.org/napi-postinstall/-/napi-postinstall-0.2.2.tgz", - "integrity": "sha512-Wy1VI/hpKHwy1MsnFxHCJxqFwmmxD0RA/EKPL7e6mfbsY01phM2SZyJnRdU0bLvhu0Quby1DCcAZti3ghdl4/A==", + "version": "0.2.4", + "resolved": "https://registry.npmjs.org/napi-postinstall/-/napi-postinstall-0.2.4.tgz", + "integrity": "sha512-ZEzHJwBhZ8qQSbknHqYcdtQVr8zUgGyM/q6h6qAyhtyVMNrSgDhrC4disf03dYW0e+czXyLnZINnCTEkWy0eJg==", "dev": true, "license": "MIT", "bin": { @@ -6729,12 +6912,12 @@ "license": "MIT" }, "node_modules/next": { - "version": "14.2.28", - "resolved": "https://registry.npmjs.org/next/-/next-14.2.28.tgz", - "integrity": "sha512-QLEIP/kYXynIxtcKB6vNjtWLVs3Y4Sb+EClTC/CSVzdLD1gIuItccpu/n1lhmduffI32iPGEK2cLLxxt28qgYA==", + "version": "14.2.29", + "resolved": "https://registry.npmjs.org/next/-/next-14.2.29.tgz", + "integrity": "sha512-s98mCOMOWLGGpGOfgKSnleXLuegvvH415qtRZXpSp00HeEgdmrxmwL9cgKU+h4XrhB16zEI5d/7BnkS3ATInsA==", "license": "MIT", "dependencies": { - "@next/env": "14.2.28", + "@next/env": "14.2.29", "@swc/helpers": "0.5.5", "busboy": "1.6.0", "caniuse-lite": "^1.0.30001579", @@ -6749,15 +6932,15 @@ "node": ">=18.17.0" }, "optionalDependencies": { - "@next/swc-darwin-arm64": "14.2.28", - "@next/swc-darwin-x64": "14.2.28", - "@next/swc-linux-arm64-gnu": "14.2.28", - "@next/swc-linux-arm64-musl": "14.2.28", - "@next/swc-linux-x64-gnu": "14.2.28", - "@next/swc-linux-x64-musl": "14.2.28", - "@next/swc-win32-arm64-msvc": "14.2.28", - 
"@next/swc-win32-ia32-msvc": "14.2.28", - "@next/swc-win32-x64-msvc": "14.2.28" + "@next/swc-darwin-arm64": "14.2.29", + "@next/swc-darwin-x64": "14.2.29", + "@next/swc-linux-arm64-gnu": "14.2.29", + "@next/swc-linux-arm64-musl": "14.2.29", + "@next/swc-linux-x64-gnu": "14.2.29", + "@next/swc-linux-x64-musl": "14.2.29", + "@next/swc-win32-arm64-msvc": "14.2.29", + "@next/swc-win32-ia32-msvc": "14.2.29", + "@next/swc-win32-x64-msvc": "14.2.29" }, "peerDependencies": { "@opentelemetry/api": "^1.1.0", @@ -7195,9 +7378,9 @@ } }, "node_modules/postcss": { - "version": "8.5.3", - "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.3.tgz", - "integrity": "sha512-dle9A3yYxlBSrt8Fu+IpjGT8SY8hN0mlaA6GY8t0P5PjIOZemULz/E2Bnm/2dcUOena75OTNkHI76uZBNUUq3A==", + "version": "8.5.4", + "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.4.tgz", + "integrity": "sha512-QSa9EBe+uwlGTFmHsPKokv3B/oEMQZxfqW0QqNCyhpa6mB1afzulwn8hihglqAb2pOw+BJgNlmXQ8la2VeHB7w==", "dev": true, "funding": [ { @@ -7215,7 +7398,7 @@ ], "license": "MIT", "dependencies": { - "nanoid": "^3.3.8", + "nanoid": "^3.3.11", "picocolors": "^1.1.1", "source-map-js": "^1.2.1" }, @@ -7311,57 +7494,58 @@ "license": "MIT" }, "node_modules/radix-ui": { - "version": "1.3.4", - "resolved": "https://registry.npmjs.org/radix-ui/-/radix-ui-1.3.4.tgz", - "integrity": "sha512-uHJD4yRGjxbEWhkVU+w9d8d+X6HUlmbesHGsE9tRWKX62FqDD3Z3hfEtVS9W+DpZAPvKSCLfz03O7un8xZT3pg==", + "version": "1.4.2", + "resolved": "https://registry.npmjs.org/radix-ui/-/radix-ui-1.4.2.tgz", + "integrity": "sha512-fT/3YFPJzf2WUpqDoQi005GS8EpCi+53VhcLaHUj5fwkPYiZAjk1mSxFvbMA8Uq71L03n+WysuYC+mlKkXxt/Q==", "license": "MIT", "dependencies": { "@radix-ui/primitive": "1.1.2", - "@radix-ui/react-accessible-icon": "1.1.4", - "@radix-ui/react-accordion": "1.2.8", - "@radix-ui/react-alert-dialog": "1.1.11", - "@radix-ui/react-arrow": "1.1.4", - "@radix-ui/react-aspect-ratio": "1.1.4", - "@radix-ui/react-avatar": "1.1.7", - "@radix-ui/react-checkbox": 
"1.2.3", - "@radix-ui/react-collapsible": "1.1.8", - "@radix-ui/react-collection": "1.1.4", + "@radix-ui/react-accessible-icon": "1.1.7", + "@radix-ui/react-accordion": "1.2.11", + "@radix-ui/react-alert-dialog": "1.1.14", + "@radix-ui/react-arrow": "1.1.7", + "@radix-ui/react-aspect-ratio": "1.1.7", + "@radix-ui/react-avatar": "1.1.10", + "@radix-ui/react-checkbox": "1.3.2", + "@radix-ui/react-collapsible": "1.1.11", + "@radix-ui/react-collection": "1.1.7", "@radix-ui/react-compose-refs": "1.1.2", "@radix-ui/react-context": "1.1.2", - "@radix-ui/react-context-menu": "2.2.12", - "@radix-ui/react-dialog": "1.1.11", + "@radix-ui/react-context-menu": "2.2.15", + "@radix-ui/react-dialog": "1.1.14", "@radix-ui/react-direction": "1.1.1", - "@radix-ui/react-dismissable-layer": "1.1.7", - "@radix-ui/react-dropdown-menu": "2.1.12", + "@radix-ui/react-dismissable-layer": "1.1.10", + "@radix-ui/react-dropdown-menu": "2.1.15", "@radix-ui/react-focus-guards": "1.1.2", - "@radix-ui/react-focus-scope": "1.1.4", - "@radix-ui/react-form": "0.1.4", - "@radix-ui/react-hover-card": "1.1.11", - "@radix-ui/react-label": "2.1.4", - "@radix-ui/react-menu": "2.1.12", - "@radix-ui/react-menubar": "1.1.12", - "@radix-ui/react-navigation-menu": "1.2.10", - "@radix-ui/react-one-time-password-field": "0.1.4", - "@radix-ui/react-popover": "1.1.11", - "@radix-ui/react-popper": "1.2.4", - "@radix-ui/react-portal": "1.1.6", + "@radix-ui/react-focus-scope": "1.1.7", + "@radix-ui/react-form": "0.1.7", + "@radix-ui/react-hover-card": "1.1.14", + "@radix-ui/react-label": "2.1.7", + "@radix-ui/react-menu": "2.1.15", + "@radix-ui/react-menubar": "1.1.15", + "@radix-ui/react-navigation-menu": "1.2.13", + "@radix-ui/react-one-time-password-field": "0.1.7", + "@radix-ui/react-password-toggle-field": "0.1.2", + "@radix-ui/react-popover": "1.1.14", + "@radix-ui/react-popper": "1.2.7", + "@radix-ui/react-portal": "1.1.9", "@radix-ui/react-presence": "1.1.4", - "@radix-ui/react-primitive": "2.1.0", - 
"@radix-ui/react-progress": "1.1.4", - "@radix-ui/react-radio-group": "1.3.4", - "@radix-ui/react-roving-focus": "1.1.7", - "@radix-ui/react-scroll-area": "1.2.6", - "@radix-ui/react-select": "2.2.2", - "@radix-ui/react-separator": "1.1.4", - "@radix-ui/react-slider": "1.3.2", - "@radix-ui/react-slot": "1.2.0", - "@radix-ui/react-switch": "1.2.2", - "@radix-ui/react-tabs": "1.1.9", - "@radix-ui/react-toast": "1.2.11", - "@radix-ui/react-toggle": "1.1.6", - "@radix-ui/react-toggle-group": "1.1.7", - "@radix-ui/react-toolbar": "1.1.7", - "@radix-ui/react-tooltip": "1.2.4", + "@radix-ui/react-primitive": "2.1.3", + "@radix-ui/react-progress": "1.1.7", + "@radix-ui/react-radio-group": "1.3.7", + "@radix-ui/react-roving-focus": "1.1.10", + "@radix-ui/react-scroll-area": "1.2.9", + "@radix-ui/react-select": "2.2.5", + "@radix-ui/react-separator": "1.1.7", + "@radix-ui/react-slider": "1.3.5", + "@radix-ui/react-slot": "1.2.3", + "@radix-ui/react-switch": "1.2.5", + "@radix-ui/react-tabs": "1.1.12", + "@radix-ui/react-toast": "1.2.14", + "@radix-ui/react-toggle": "1.1.9", + "@radix-ui/react-toggle-group": "1.1.10", + "@radix-ui/react-toolbar": "1.1.10", + "@radix-ui/react-tooltip": "1.2.7", "@radix-ui/react-use-callback-ref": "1.1.1", "@radix-ui/react-use-controllable-state": "1.2.2", "@radix-ui/react-use-effect-event": "0.0.2", @@ -7369,7 +7553,7 @@ "@radix-ui/react-use-is-hydrated": "0.1.0", "@radix-ui/react-use-layout-effect": "1.1.1", "@radix-ui/react-use-size": "1.1.1", - "@radix-ui/react-visually-hidden": "1.2.0" + "@radix-ui/react-visually-hidden": "1.2.3" }, "peerDependencies": { "@types/react": "*", @@ -7412,9 +7596,9 @@ } }, "node_modules/react-hook-form": { - "version": "7.56.1", - "resolved": "https://registry.npmjs.org/react-hook-form/-/react-hook-form-7.56.1.tgz", - "integrity": "sha512-qWAVokhSpshhcEuQDSANHx3jiAEFzu2HAaaQIzi/r9FNPm1ioAvuJSD4EuZzWd7Al7nTRKcKPnBKO7sRn+zavQ==", + "version": "7.57.0", + "resolved": 
"https://registry.npmjs.org/react-hook-form/-/react-hook-form-7.57.0.tgz", + "integrity": "sha512-RbEks3+cbvTP84l/VXGUZ+JMrKOS8ykQCRYdm5aYsxnDquL0vspsyNhGRO7pcH6hsZqWlPOjLye7rJqdtdAmlg==", "license": "MIT", "engines": { "node": ">=18.0.0" @@ -7434,9 +7618,9 @@ "license": "MIT" }, "node_modules/react-remove-scroll": { - "version": "2.6.3", - "resolved": "https://registry.npmjs.org/react-remove-scroll/-/react-remove-scroll-2.6.3.tgz", - "integrity": "sha512-pnAi91oOk8g8ABQKGF5/M9qxmmOPxaAnopyTHYfqYEwJhyFrbbBtHuSgtKEoH0jpcxx5o3hXqH1mNd9/Oi+8iQ==", + "version": "2.7.1", + "resolved": "https://registry.npmjs.org/react-remove-scroll/-/react-remove-scroll-2.7.1.tgz", + "integrity": "sha512-HpMh8+oahmIdOuS5aFKKY6Pyog+FNaZV/XyJOq7b4YFwsFHe5yYfdbIalI4k3vU2nSDql7YskmUseHsRrJqIPA==", "license": "MIT", "dependencies": { "react-remove-scroll-bar": "^2.3.7", @@ -7600,12 +7784,6 @@ "url": "https://github.com/sponsors/ljharb" } }, - "node_modules/regenerator-runtime": { - "version": "0.14.1", - "resolved": "https://registry.npmjs.org/regenerator-runtime/-/regenerator-runtime-0.14.1.tgz", - "integrity": "sha512-dYnhHh0nJoMfnkZs6GmmhFknAGRrLznOu5nc9ML+EJxGvrx6H7teuevqVqCuPcPK//3eDrrjQhehXVx9cnkGdw==", - "license": "MIT" - }, "node_modules/regexp.prototype.flags": { "version": "1.5.4", "resolved": "https://registry.npmjs.org/regexp.prototype.flags/-/regexp.prototype.flags-1.5.4.tgz", @@ -7870,9 +8048,9 @@ "license": "MIT" }, "node_modules/semver": { - "version": "7.7.1", - "resolved": "https://registry.npmjs.org/semver/-/semver-7.7.1.tgz", - "integrity": "sha512-hlq8tAfn0m/61p4BVRcPzIGr6LKiMwo4VM6dGi6pt4qcRkmNzTcWq6eCEjEh+qXjkMDvPlOFFSGwQjoEa6gyMA==", + "version": "7.7.2", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.7.2.tgz", + "integrity": "sha512-RF0Fw+rO5AMf9MAyaRXI4AV0Ulj5lMHqVxxdSgiVbixSCXoEmmX/jk0CuJw4+3SqroYO9VoUh+HcuJivvtJemA==", "dev": true, "license": "ISC", "bin": { @@ -8099,6 +8277,20 @@ "dev": true, "license": "MIT" }, + 
"node_modules/stop-iteration-iterator": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/stop-iteration-iterator/-/stop-iteration-iterator-1.1.0.tgz", + "integrity": "sha512-eLoXW/DHyl62zxY4SCaIgnRhuMr6ri4juEYARS8E6sCEqzKpOiE521Ucofdx+KnDZl5xmvGYaaKCk5FEOxJCoQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0", + "internal-slot": "^1.1.0" + }, + "engines": { + "node": ">= 0.4" + } + }, "node_modules/streamsearch": { "version": "1.1.0", "resolved": "https://registry.npmjs.org/streamsearch/-/streamsearch-1.1.0.tgz", @@ -8433,9 +8625,9 @@ } }, "node_modules/tailwindcss": { - "version": "4.1.4", - "resolved": "https://registry.npmjs.org/tailwindcss/-/tailwindcss-4.1.4.tgz", - "integrity": "sha512-1ZIUqtPITFbv/DxRmDr5/agPqJwF69d24m9qmM1939TJehgY539CtzeZRjbLt5G6fSy/7YqqYsfvoTEw9xUI2A==", + "version": "4.1.8", + "resolved": "https://registry.npmjs.org/tailwindcss/-/tailwindcss-4.1.8.tgz", + "integrity": "sha512-kjeW8gjdxasbmFKpVGrGd5T4i40mV5J2Rasw48QARfYeQ8YS9x02ON9SFWax3Qf616rt4Cp3nVNIj6Hd1mP3og==", "dev": true, "license": "MIT" }, @@ -8450,15 +8642,33 @@ } }, "node_modules/tapable": { - "version": "2.2.1", - "resolved": "https://registry.npmjs.org/tapable/-/tapable-2.2.1.tgz", - "integrity": "sha512-GNzQvQTOIP6RyTfE2Qxb8ZVlNmw0n88vp1szwWRimP02mnTsx3Wtn5qRdqY9w2XduFNUgvOwhNnQsjwCp+kqaQ==", + "version": "2.2.2", + "resolved": "https://registry.npmjs.org/tapable/-/tapable-2.2.2.tgz", + "integrity": "sha512-Re10+NauLTMCudc7T5WLFLAwDhQ0JWdrMK+9B2M8zR5hRExKmsRDCBA7/aV/pNJFltmBFO5BAMlQFi/vq3nKOg==", "dev": true, "license": "MIT", "engines": { "node": ">=6" } }, + "node_modules/tar": { + "version": "7.4.3", + "resolved": "https://registry.npmjs.org/tar/-/tar-7.4.3.tgz", + "integrity": "sha512-5S7Va8hKfV7W5U6g3aYxXmlPoZVAwUMy9AOKyF2fVuZa2UD3qZjg578OrLRt8PcNN1PleVaL/5/yYATNL0ICUw==", + "dev": true, + "license": "ISC", + "dependencies": { + "@isaacs/fs-minipass": "^4.0.0", + "chownr": "^3.0.0", + "minipass": "^7.1.2", + "minizlib": 
"^3.0.1", + "mkdirp": "^3.0.1", + "yallist": "^5.0.0" + }, + "engines": { + "node": ">=18" + } + }, "node_modules/text-table": { "version": "0.2.0", "resolved": "https://registry.npmjs.org/text-table/-/text-table-0.2.0.tgz", @@ -8479,9 +8689,9 @@ "license": "MIT" }, "node_modules/tinyglobby": { - "version": "0.2.13", - "resolved": "https://registry.npmjs.org/tinyglobby/-/tinyglobby-0.2.13.tgz", - "integrity": "sha512-mEwzpUgrLySlveBwEVDMKk5B57bhLPYovRfPAXD5gA/98Opn0rCDj3GtLwFvCvH5RK9uPCExUROW5NjDwvqkxw==", + "version": "0.2.14", + "resolved": "https://registry.npmjs.org/tinyglobby/-/tinyglobby-0.2.14.tgz", + "integrity": "sha512-tX5e7OM1HnYr2+a2C/4V0htOcSQcoSTH9KgJnVvNm5zm/cyEWKJ7j7YutsH9CxMdtOkkLFy2AHrMci9IM8IPZQ==", "dev": true, "license": "MIT", "dependencies": { @@ -8496,9 +8706,9 @@ } }, "node_modules/tinyglobby/node_modules/fdir": { - "version": "6.4.4", - "resolved": "https://registry.npmjs.org/fdir/-/fdir-6.4.4.tgz", - "integrity": "sha512-1NZP+GK4GfuAv3PqKvxQRDMjdSRZjnkq7KfhlNrCNNlZ0ygQFpebfrnfnq/W7fpUnAv9aGWmY1zKx7FYL3gwhg==", + "version": "6.4.5", + "resolved": "https://registry.npmjs.org/fdir/-/fdir-6.4.5.tgz", + "integrity": "sha512-4BG7puHpVsIYxZUbiUE3RqGloLaSSwzYie5jvasC4LWuBWzZawynvYouhjbQKw2JuIGYdm0DzIxl8iVidKlUEw==", "dev": true, "license": "MIT", "peerDependencies": { @@ -8579,9 +8789,9 @@ "license": "0BSD" }, "node_modules/tw-animate-css": { - "version": "1.2.8", - "resolved": "https://registry.npmjs.org/tw-animate-css/-/tw-animate-css-1.2.8.tgz", - "integrity": "sha512-AxSnYRvyFnAiZCUndS3zQZhNfV/B77ZhJ+O7d3K6wfg/jKJY+yv6ahuyXwnyaYA9UdLqnpCwhTRv9pPTBnPR2g==", + "version": "1.3.4", + "resolved": "https://registry.npmjs.org/tw-animate-css/-/tw-animate-css-1.3.4.tgz", + "integrity": "sha512-dd1Ht6/YQHcNbq0znIT6dG8uhO7Ce+VIIhZUhjsryXsMPJQz3bZg7Q2eNzLwipb25bRZslGb2myio5mScd1TFg==", "dev": true, "license": "MIT", "funding": { @@ -8742,9 +8952,9 @@ "license": "MIT" }, "node_modules/unrs-resolver": { - "version": "1.7.2", - "resolved": 
"https://registry.npmjs.org/unrs-resolver/-/unrs-resolver-1.7.2.tgz", - "integrity": "sha512-BBKpaylOW8KbHsu378Zky/dGh4ckT/4NW/0SHRABdqRLcQJ2dAOjDo9g97p04sWflm0kqPqpUatxReNV/dqI5A==", + "version": "1.7.10", + "resolved": "https://registry.npmjs.org/unrs-resolver/-/unrs-resolver-1.7.10.tgz", + "integrity": "sha512-CJEMJcz6vuwRK6xxWc+uf8AGi0OyfoVtHs5mExtNecS0HZq3a3Br1JC/InwwTn6uy+qkAdAdK+nJUYO9FPtgZw==", "dev": true, "hasInstallScript": true, "license": "MIT", @@ -8752,26 +8962,26 @@ "napi-postinstall": "^0.2.2" }, "funding": { - "url": "https://github.com/sponsors/JounQin" + "url": "https://opencollective.com/unrs-resolver" }, "optionalDependencies": { - "@unrs/resolver-binding-darwin-arm64": "1.7.2", - "@unrs/resolver-binding-darwin-x64": "1.7.2", - "@unrs/resolver-binding-freebsd-x64": "1.7.2", - "@unrs/resolver-binding-linux-arm-gnueabihf": "1.7.2", - "@unrs/resolver-binding-linux-arm-musleabihf": "1.7.2", - "@unrs/resolver-binding-linux-arm64-gnu": "1.7.2", - "@unrs/resolver-binding-linux-arm64-musl": "1.7.2", - "@unrs/resolver-binding-linux-ppc64-gnu": "1.7.2", - "@unrs/resolver-binding-linux-riscv64-gnu": "1.7.2", - "@unrs/resolver-binding-linux-riscv64-musl": "1.7.2", - "@unrs/resolver-binding-linux-s390x-gnu": "1.7.2", - "@unrs/resolver-binding-linux-x64-gnu": "1.7.2", - "@unrs/resolver-binding-linux-x64-musl": "1.7.2", - "@unrs/resolver-binding-wasm32-wasi": "1.7.2", - "@unrs/resolver-binding-win32-arm64-msvc": "1.7.2", - "@unrs/resolver-binding-win32-ia32-msvc": "1.7.2", - "@unrs/resolver-binding-win32-x64-msvc": "1.7.2" + "@unrs/resolver-binding-darwin-arm64": "1.7.10", + "@unrs/resolver-binding-darwin-x64": "1.7.10", + "@unrs/resolver-binding-freebsd-x64": "1.7.10", + "@unrs/resolver-binding-linux-arm-gnueabihf": "1.7.10", + "@unrs/resolver-binding-linux-arm-musleabihf": "1.7.10", + "@unrs/resolver-binding-linux-arm64-gnu": "1.7.10", + "@unrs/resolver-binding-linux-arm64-musl": "1.7.10", + "@unrs/resolver-binding-linux-ppc64-gnu": "1.7.10", + 
"@unrs/resolver-binding-linux-riscv64-gnu": "1.7.10", + "@unrs/resolver-binding-linux-riscv64-musl": "1.7.10", + "@unrs/resolver-binding-linux-s390x-gnu": "1.7.10", + "@unrs/resolver-binding-linux-x64-gnu": "1.7.10", + "@unrs/resolver-binding-linux-x64-musl": "1.7.10", + "@unrs/resolver-binding-wasm32-wasi": "1.7.10", + "@unrs/resolver-binding-win32-arm64-msvc": "1.7.10", + "@unrs/resolver-binding-win32-ia32-msvc": "1.7.10", + "@unrs/resolver-binding-win32-x64-msvc": "1.7.10" } }, "node_modules/uri-js": { @@ -9091,6 +9301,16 @@ "dev": true, "license": "ISC" }, + "node_modules/yallist": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/yallist/-/yallist-5.0.0.tgz", + "integrity": "sha512-YgvUTfwqyc7UXVMrB+SImsVYSmTS8X/tSrtdNZMImM+n7+QTriRXyXim0mBrTXNeqzVF0KWGgHPeiyViFFrNDw==", + "dev": true, + "license": "BlueOak-1.0.0", + "engines": { + "node": ">=18" + } + }, "node_modules/yaml": { "version": "2.3.1", "resolved": "https://registry.npmjs.org/yaml/-/yaml-2.3.1.tgz", @@ -9115,18 +9335,18 @@ } }, "node_modules/zod": { - "version": "3.24.3", - "resolved": "https://registry.npmjs.org/zod/-/zod-3.24.3.tgz", - "integrity": "sha512-HhY1oqzWCQWuUqvBFnsyrtZRhyPeR7SUGv+C4+MsisMuVfSPx8HpwWqH8tRahSlt6M3PiFAcoeFhZAqIXTxoSg==", + "version": "3.25.51", + "resolved": "https://registry.npmjs.org/zod/-/zod-3.25.51.tgz", + "integrity": "sha512-TQSnBldh+XSGL+opiSIq0575wvDPqu09AqWe1F7JhUMKY+M91/aGlK4MhpVNO7MgYfHcVCB1ffwAUTJzllKJqg==", "license": "MIT", "funding": { "url": "https://github.com/sponsors/colinhacks" } }, "node_modules/zustand": { - "version": "5.0.3", - "resolved": "https://registry.npmjs.org/zustand/-/zustand-5.0.3.tgz", - "integrity": "sha512-14fwWQtU3pH4dE0dOpdMiWjddcH+QzKIgk1cl8epwSE7yag43k/AD/m4L6+K7DytAOr9gGBe3/EXj9g7cdostg==", + "version": "5.0.5", + "resolved": "https://registry.npmjs.org/zustand/-/zustand-5.0.5.tgz", + "integrity": "sha512-mILtRfKW9xM47hqxGIxCv12gXusoY/xTSHBYApXozR0HmQv299whhBeeAcRy+KrPPybzosvJBCOmVjq6x12fCg==", "license": "MIT", 
"engines": { "node": ">=12.20.0" diff --git a/frontend/tsconfig.json b/frontend/tsconfig.json index 7b28589..f48e7ee 100644 --- a/frontend/tsconfig.json +++ b/frontend/tsconfig.json @@ -1,6 +1,10 @@ { "compilerOptions": { - "lib": ["dom", "dom.iterable", "esnext"], + "lib": [ + "dom", + "dom.iterable", + "esnext" + ], "allowJs": true, "skipLibCheck": true, "strict": true, @@ -18,9 +22,19 @@ } ], "paths": { - "@/*": ["./src/*"] - } + "@/*": [ + "./src/*" + ] + }, + "target": "ES2017" }, - "include": ["next-env.d.ts", "**/*.ts", "**/*.tsx", ".next/types/**/*.ts"], - "exclude": ["node_modules"] + "include": [ + "next-env.d.ts", + "**/*.ts", + "**/*.tsx", + ".next/types/**/*.ts" + ], + "exclude": [ + "node_modules" + ] } From 0d5a574e1de30a4007c47f1f23378e40798218b3 Mon Sep 17 00:00:00 2001 From: poornimaramesh Date: Mon, 9 Jun 2025 18:55:08 +0300 Subject: [PATCH 61/74] update display experiments page --- .../src/app/(protected)/experiments/api.ts | 105 +------ .../experiments/components/ExperimentCard.tsx | 159 ++++++++-- .../components/cards/createBayesABCard.tsx | 64 ---- .../components/cards/createCMABCard.tsx | 75 ----- .../components/cards/createMABCard.tsx | 128 -------- .../src/app/(protected)/experiments/page.tsx | 26 +- .../experiments/store/useExperimentStore.ts | 294 ++++++------------ .../src/app/(protected)/experiments/types.ts | 139 ++------- 8 files changed, 265 insertions(+), 725 deletions(-) delete mode 100644 frontend/src/app/(protected)/experiments/components/cards/createBayesABCard.tsx delete mode 100644 frontend/src/app/(protected)/experiments/components/cards/createCMABCard.tsx delete mode 100644 frontend/src/app/(protected)/experiments/components/cards/createMABCard.tsx diff --git a/frontend/src/app/(protected)/experiments/api.ts b/frontend/src/app/(protected)/experiments/api.ts index 9b2136c..32d8fd6 100644 --- a/frontend/src/app/(protected)/experiments/api.ts +++ b/frontend/src/app/(protected)/experiments/api.ts @@ -1,43 +1,17 @@ import api 
from "@/utils/api"; -import { ExperimentState, MABBeta, MABNormal, CMAB, BayesianAB } from "./types"; -import { - isMABExperimentStateBeta, - isMABExperimentStateNormal, - isCMABExperimentState, - isBayesianABState, -} from "./store/useExperimentStore"; +import { ExperimentState, NewExperimentState } from "./types"; + const createNewExperiment = async ({ experimentData, token, }: { - experimentData: ExperimentState; + experimentData: NewExperimentState; token: string | null; }) => { - const getEndpointAndData = ( - data: ExperimentState - ): { - endpoint: string; - } => { - - if (isMABExperimentStateBeta(data) || isMABExperimentStateNormal(data)) { - return { endpoint: "/mab/" }; - } - - if (isCMABExperimentState(data)) { - return { endpoint: "/contextual_mab/" }; - } - - if (isBayesianABState(data)) { - return { endpoint: "/bayes_ab/" }; - } - - throw new Error("Invalid experiment type"); - }; try { - const { endpoint } = getEndpointAndData(experimentData); - const response = await api.post(endpoint, experimentData, { + const response = await api.post("/experiment", experimentData, { headers: { Authorization: `Bearer ${token}`, }, @@ -51,20 +25,14 @@ const createNewExperiment = async ({ } }; -const getAllMABExperiments = async (token: string | null) => { +const getExperimentsByType = async (token: string | null, exp_type: string) => { try { - const response = await api.get("/mab/", { + const response = await api.get(`/experiment/type/${exp_type}`, { headers: { Authorization: `Bearer ${token}`, }, }); - const convertedData = response.data.map( - (experiment: MABBeta | MABNormal) => ({ - ...experiment, - methodType: "mab", - }) - ); - return convertedData; + return response.data as ExperimentState[]; } catch (error: unknown) { if (error instanceof Error) { throw new Error(`Error fetching all experiments: ${error.message}`); @@ -74,60 +42,14 @@ const getAllMABExperiments = async (token: string | null) => { } }; -const getAllCMABExperiments = async (token: string | 
null) => { +const getExperimentById = async (token: string | null, id: number) => { try { - const response = await api.get("/contextual_mab/", { + const response = await api.get(`/experiment/${id}/`, { headers: { Authorization: `Bearer ${token}`, }, }); - const convertedData = response.data.map((experiment: CMAB) => ({ - ...experiment, - methodType: "cmab", - })); - return convertedData; - } catch (error: unknown) { - if (error instanceof Error) { - throw new Error(`Error fetching all experiments: ${error.message}`); - } else { - throw new Error("Error fetching all experiments"); - } - } -}; - -const getAllBayesianABExperiments = async (token: string | null) => { - try { - const response = await api.get("/bayes_ab/", { - headers: { - Authorization: `Bearer ${token}`, - }, - }); - const convertedData = response.data.map((experiment: BayesianAB) => ({ - ...experiment, - methodType: "bayes_ab", - })); - return convertedData; - } catch (error: unknown) { - if (error instanceof Error) { - throw new Error(`Error fetching all experiments: ${error.message}`); - } else { - throw new Error("Error fetching all experiments"); - } - } -}; - -const getMABExperimentById = async (token: string | null, id: number) => { - try { - const response = await api.get(`/mab/${id}/`, { - headers: { - Authorization: `Bearer ${token}`, - }, - }); - const convertedData = { - ...response.data, - methodType: "mab", - }; - return convertedData; + return response.data as ExperimentState; } catch (error: unknown) { if (error instanceof Error) { throw new Error(`Error fetching experiment: ${error.message}`); @@ -136,10 +58,9 @@ const getMABExperimentById = async (token: string | null, id: number) => { } } }; + export { createNewExperiment, - getAllMABExperiments, - getAllCMABExperiments, - getAllBayesianABExperiments, - getMABExperimentById, + getExperimentsByType, + getExperimentById, }; diff --git a/frontend/src/app/(protected)/experiments/components/ExperimentCard.tsx 
b/frontend/src/app/(protected)/experiments/components/ExperimentCard.tsx index 3ff2749..de906e3 100644 --- a/frontend/src/app/(protected)/experiments/components/ExperimentCard.tsx +++ b/frontend/src/app/(protected)/experiments/components/ExperimentCard.tsx @@ -1,44 +1,139 @@ +import { ExperimentState, Arm } from "../types"; import { Card, CardContent, + CardHeader, CardTitle, CardDescription, } from "@/components/ui/card"; -import { MABBeta, MABNormal, CMAB, BayesianAB, MethodType } from "../types"; -import { MABBetaCards, MABNormalCards } from "./cards/createMABCard"; -import { CMABCards } from "./cards/createCMABCard"; -import { BayesianABCards } from "./cards/createBayesABCard"; +import { Progress } from "@/components/ui/progress"; +import { useRouter } from "next/navigation"; +import { isMABExperimentStateBeta } from "../store/useExperimentStore"; -export default function ExperimentCards({ +const calculateDaysAgo = (dateString: string) => { + const date = new Date(dateString); + const now = new Date(); + const diffTime = Math.abs(now.getTime() - date.getTime()); + const diffDays = Math.ceil(diffTime / (1000 * 60 * 60 * 24)); + return diffDays; +}; + +interface ExperimentCardProps { + experiment: { + experiment_id: number | string; + name: string; + is_active: boolean; + last_trial_datetime_utc?: string; + arms: T[]; + }; + calculateProgressValue: (arm: T, maxValue?: number) => number; + formatDisplayValue: (arm: T) => string; + maxValue?: number; +} + +export function BaseExperimentCard({ experiment, - methodType, -}: { - experiment: MABBeta | MABNormal | CMAB | BayesianAB; - methodType: MethodType; -}) { - if (methodType === "mab" && experiment.prior_type === "beta") { - const betaExperiment = experiment as MABBeta; - return ; - } else if (methodType === "mab" && experiment.prior_type === "normal") { - const normalExperiment = experiment as MABNormal; - return ; - } else if (methodType === "cmab") { - const cmabExperiment = experiment as CMAB; - return ; - } 
else if (methodType === "bayes_ab") { - const bayesExperiment = experiment as BayesianAB; - return ; - } + calculateProgressValue, + formatDisplayValue, + maxValue, +}: ExperimentCardProps) { + const { experiment_id, name, is_active, arms } = experiment; + const router = useRouter(); - // Default case for other experiment types return ( - - - Unsupported Experiment Type - - This experiment type is not yet supported. - - - +
+ { + router.push(`/experiments/${experiment_id}`); + }} + > + +
+ {name} + + ID: {experiment_id} + +
+
+
+ + {is_active ? "Active" : "Not Active"} + +
+ + +
+
+ {arms && + arms.map((arm, index) => ( +
+
{arm.name}
+ +
+ {formatDisplayValue(arm)} +
+
+ ))} +
+
+ + Last Run: + + + {experiment.last_trial_datetime_utc + ? `${calculateDaysAgo( + experiment.last_trial_datetime_utc + )} days ago` + : "N/A"} + +
+
+
+ +
+ ); +} + +export function ExperimentCard({ experiment }: { experiment: ExperimentState }) { + const maxValue = !isMABExperimentStateBeta(experiment) + ? Math.max( + ...experiment.arms.map((arm: Arm) => + arm.mu ? arm.mu.reduce((a, b) => a + b, 0) : 0 + ), + 0 + ) + : 0; + + return isMABExperimentStateBeta(experiment) ? ( + + arm.alpha && arm.beta ? (arm.alpha * 100) / (arm.alpha + arm.beta) : 0 + } + formatDisplayValue={(arm: Arm) => + arm.alpha && arm.beta + ? `${((arm.alpha * 100) / (arm.alpha + arm.beta)).toFixed(1)}%` + : "N/A" + } + /> + ) : ( + + maxValue && arm.mu + ? arm.mu.reduce((a, b) => (a + b) * 100, 0) / (maxValue * 1.5) + : 0 + } + formatDisplayValue={(arm: Arm) => + arm.mu ? `${arm.mu.reduce((a, b) => a + b, 0).toFixed(1)}` : "N/A" + } + maxValue={maxValue} + /> ); } diff --git a/frontend/src/app/(protected)/experiments/components/cards/createBayesABCard.tsx b/frontend/src/app/(protected)/experiments/components/cards/createBayesABCard.tsx deleted file mode 100644 index 5cab8d5..0000000 --- a/frontend/src/app/(protected)/experiments/components/cards/createBayesABCard.tsx +++ /dev/null @@ -1,64 +0,0 @@ -import { BayesianAB } from "../../types"; -import { - Card, - CardContent, - CardHeader, - CardTitle, - CardDescription, -} from "@/components/ui/card"; - -export function BayesianABCards({ experiment }: { experiment: BayesianAB }) { - const { experiment_id, name, is_active, arms } = { ...experiment }; - - return ( -
- { - console.log("Details page not built yet"); - }} - > - -
- {name} - - ID: {experiment_id} - -
-
-
- - {is_active ? "Active" : "Not Active"} - -
- - -
-
- {arms && - arms.map((dist, index) => ( -
-
{dist.name}
-
-
-
-
-
- ))} -
-
-
- Last Run: 2 days ago -
-
-
- - -
- ); -} diff --git a/frontend/src/app/(protected)/experiments/components/cards/createCMABCard.tsx b/frontend/src/app/(protected)/experiments/components/cards/createCMABCard.tsx deleted file mode 100644 index 47b4d3d..0000000 --- a/frontend/src/app/(protected)/experiments/components/cards/createCMABCard.tsx +++ /dev/null @@ -1,75 +0,0 @@ -import { CMAB } from "../../types"; -import { - Card, - CardContent, - CardHeader, - CardTitle, - CardDescription, -} from "@/components/ui/card"; - -export function CMABCards({ experiment }: { experiment: CMAB }) { - const { experiment_id, name, is_active, arms, contexts } = { ...experiment }; - - return ( -
- { - e.stopPropagation(); - console.log("Details page not built yet"); - }} - > - -
- {name} - - ID: {experiment_id} - -
-
-
- - {is_active ? "Active" : "Not Active"} - -
- - -
-
- {arms && - arms.map((dist, index) => ( -
-
{dist.name}
-
-
-
-
-
- ))} - {contexts && - contexts.map((context, index) => ( -
-
{context.name}
-
-
-
-
-
- ))} -
-
-
- Last Run: 2 days ago -
-
-
- - -
- ); -} diff --git a/frontend/src/app/(protected)/experiments/components/cards/createMABCard.tsx b/frontend/src/app/(protected)/experiments/components/cards/createMABCard.tsx deleted file mode 100644 index 07679d3..0000000 --- a/frontend/src/app/(protected)/experiments/components/cards/createMABCard.tsx +++ /dev/null @@ -1,128 +0,0 @@ -import { MABBeta, MABNormal, MABArmBeta, MABArmNormal } from "../../types"; -import { - Card, - CardContent, - CardHeader, - CardTitle, - CardDescription, -} from "@/components/ui/card"; -import { Progress } from "@/components/ui/progress"; -import { useRouter } from "next/navigation"; - -const calculateDaysAgo = (dateString: string) => { - const date = new Date(dateString); - const now = new Date(); - const diffTime = Math.abs(now.getTime() - date.getTime()); - const diffDays = Math.ceil(diffTime / (1000 * 60 * 60 * 24)); - return diffDays; -}; - -interface ExperimentCardProps { - experiment: { - experiment_id: number | string; - name: string; - is_active: boolean; - last_trial_datetime_utc?: string; - arms: T[]; - }; - calculateProgressValue: (arm: T, maxValue?: number) => number; - formatDisplayValue: (arm: T) => string; - maxValue?: number; -} - -export function ExperimentCard({ - experiment, - calculateProgressValue, - formatDisplayValue, - maxValue, -}: ExperimentCardProps) { - const { experiment_id, name, is_active, arms } = experiment; - const router = useRouter(); - - return ( -
- { - router.push(`/experiments/${experiment_id}`); - }} - > - -
- {name} - - ID: {experiment_id} - -
-
-
- - {is_active ? "Active" : "Not Active"} - -
- - -
-
- {arms && - arms.map((arm, index) => ( -
-
{arm.name}
- -
- {formatDisplayValue(arm)} -
-
- ))} -
-
- - Last Run: - - - {experiment.last_trial_datetime_utc - ? `${calculateDaysAgo( - experiment.last_trial_datetime_utc - )} days ago` - : "N/A"} - -
-
-
- -
- ); -} - -export function MABBetaCards({ experiment }: { experiment: MABBeta }) { - return ( - - (arm.alpha * 100) / (arm.alpha + arm.beta) - } - formatDisplayValue={(arm: MABArmBeta) => - `${((arm.alpha * 100) / (arm.alpha + arm.beta)).toFixed(1)}%` - } - /> - ); -} - -export function MABNormalCards({ experiment }: { experiment: MABNormal }) { - const maxValue = Math.max(...experiment.arms.map((arm) => arm.mu), 0); - return ( - - maxValue ? (arm.mu * 100) / (maxValue * 1.5) : 0 - } - formatDisplayValue={(arm: MABArmNormal) => `${arm.mu.toFixed(1)}`} - maxValue={maxValue} - /> - ); -} diff --git a/frontend/src/app/(protected)/experiments/page.tsx b/frontend/src/app/(protected)/experiments/page.tsx index 2f222ab..a37f4a1 100644 --- a/frontend/src/app/(protected)/experiments/page.tsx +++ b/frontend/src/app/(protected)/experiments/page.tsx @@ -2,12 +2,10 @@ import React, { useEffect } from "react"; import EmptyPage from "./components/EmptyPage"; import { - getAllMABExperiments, - getAllCMABExperiments, - getAllBayesianABExperiments, + getExperimentsByType, } from "./api"; -import { MABBeta, MABNormal, CMAB, BayesianAB, MethodType } from "./types"; -import ExperimentCard from "./components/ExperimentCard"; +import { ExperimentState, MethodType } from "./types"; +import { ExperimentCard } from "./components/ExperimentCard"; import Hourglass from "@/components/Hourglass"; import FloatingAddButton from "./components/FloatingAddButton"; import Link from "next/link"; @@ -16,11 +14,9 @@ import { DividerWithTitle } from "@/components/Dividers"; export default function Experiments() { const [haveExperiments, setHaveExperiments] = React.useState(false); - const [mabExperiments, setMABExperiments] = React.useState([]); - const [cmabExperiments, setCMABExperiments] = React.useState([]); - const [bayesExperiments, setBayesExperiments] = React.useState( - [] - ); + const [mabExperiments, setMABExperiments] = React.useState([]); + const [cmabExperiments, setCMABExperiments] = 
React.useState([]); + const [bayesExperiments, setBayesExperiments] = React.useState([]); const [loading, setLoading] = React.useState(true); const [loadingError, setLoadingError] = React.useState(""); @@ -33,9 +29,9 @@ export default function Experiments() { const fetchData = async () => { try { const [mabData, cmabData, bayesabData] = await Promise.all([ - getAllMABExperiments(token), - getAllCMABExperiments(token), - getAllBayesianABExperiments(token), + getExperimentsByType(token, "mab"), + getExperimentsByType(token, "cmab"), + getExperimentsByType(token, "bayes_ab"), ]); setMABExperiments(mabData); setCMABExperiments(cmabData); @@ -130,7 +126,7 @@ const ExperimentCardGrid = ({ experiments, methodType, }: { - experiments: MABBeta[] | MABNormal[] | CMAB[] | BayesianAB[]; + experiments: ExperimentState[]; methodType: MethodType; }) => { return ( @@ -140,7 +136,7 @@ const ExperimentCardGrid = ({ > {experiments.map((experiment) => (
  • - +
  • ))} diff --git a/frontend/src/app/(protected)/experiments/store/useExperimentStore.ts b/frontend/src/app/(protected)/experiments/store/useExperimentStore.ts index 53ad153..3c5a02b 100644 --- a/frontend/src/app/(protected)/experiments/store/useExperimentStore.ts +++ b/frontend/src/app/(protected)/experiments/store/useExperimentStore.ts @@ -2,50 +2,31 @@ import { create } from "zustand"; import { persist, createJSONStorage } from "zustand/middleware"; import type { ExperimentState, - CMABExperimentState, - NewCMABArm, - MABExperimentStateBeta, - MABExperimentStateNormal, - BayesianABState, + NewExperimentState, + NewArm, PriorType, RewardType, MethodType, - NewMABArmBeta, - NewMABArmNormal, - NewBayesianABArm, NewContext, Notifications } from "../types"; -// Type guards for better type safety -export function isMABExperimentStateBeta( - state: ExperimentState -): state is MABExperimentStateBeta { - return state.methodType === "mab" && state.prior_type === "beta"; -} - -export function isMABExperimentStateNormal( - state: ExperimentState -): state is MABExperimentStateNormal { - return state.methodType === "mab" && state.prior_type === "normal"; -} +export const isMABExperimentStateBeta = (experimentState: NewExperimentState) => { + return (experimentState.prior_type === "beta" && experimentState.exp_type === "mab"); +}; -export function isCMABExperimentState( - state: ExperimentState -): state is CMABExperimentState { - return state.methodType === "cmab"; -} +export const isCMABExperimentState = (experimentState: NewExperimentState) => { + return experimentState.exp_type === "cmab"; +}; -export function isBayesianABState( - state: ExperimentState -): state is BayesianABState { - return state.methodType === "bayes_ab"; -} +export const isBayesianABState = (experimentState: NewExperimentState) => { + return experimentState.exp_type === "bayes_ab"; +}; // Define store interface ExperimentStore { - experimentState: ExperimentState; + experimentState: 
NewExperimentState; // basicInfoPage updateName: (name: string) => void; @@ -62,11 +43,11 @@ interface ExperimentStore { // Arms updates updateArms: ( - arms: NewMABArmBeta[] | NewMABArmNormal[] | NewCMABArm[] | NewBayesianABArm[] + arms: NewArm[] ) => void; updateArm: ( index: number, - arm: Partial + arm: Partial ) => void; addArm: () => void; removeArm: (index: number) => void; @@ -86,7 +67,7 @@ interface ExperimentStore { resetState: () => void; } -const createInitialState = (): ExperimentState => { +const createInitialState = (): NewExperimentState => { const baseDescr = { name: "", description: "", @@ -95,13 +76,13 @@ const createInitialState = (): ExperimentState => { auto_fail_value: 10, auto_fail_unit: "days", }; - const methodType: MethodType = "mab"; + const exp_type: MethodType = "mab"; const prior_type: PriorType = "beta"; const reward_type: RewardType = "binary"; - const baseMABState = { + const baseState = { ...baseDescr, - methodType, + exp_type, reward_type, prior_type, notifications: { @@ -115,22 +96,28 @@ const createInitialState = (): ExperimentState => { }; return { - ...baseMABState, + ...baseState, + last_trial_datetime_utc: null, arms: [ { name: "", description: "", alpha_init: 1, beta_init: 1, - } as NewMABArmBeta, + mu_init: 0, + sigma_init: 1, + } as NewArm, { name: "", description: "", alpha_init: 1, beta_init: 1, - } as NewMABArmBeta, + mu_init: 0, + sigma_init: 1, + } as NewArm, ], - } as MABExperimentStateBeta; + contexts: [], + } as NewExperimentState; }; export const useExperimentStore = create()( @@ -182,17 +169,17 @@ export const useExperimentStore = create()( updateMethodType: (newMethodType: MethodType) => set((state) => { const { experimentState } = state; - if (newMethodType === experimentState.methodType) + if (newMethodType === experimentState.exp_type) return { experimentState }; const { reward_type, notifications } = experimentState; - let newState: ExperimentState; + let newState: NewExperimentState; if (newMethodType == 
"mab") { newState = { ...experimentState, - methodType: newMethodType, + exp_type: newMethodType, prior_type: "beta", reward_type, notifications, @@ -202,15 +189,15 @@ export const useExperimentStore = create()( description: "", alpha_init: 1, beta_init: 1, - } as NewMABArmBeta, + } as NewArm, { name: "", description: "", alpha_init: 1, beta_init: 1, - } as NewMABArmBeta, + } as NewArm, ], - } as MABExperimentStateBeta; + } as NewExperimentState; } else if (newMethodType == "cmab") { newState = { ...experimentState, @@ -224,13 +211,13 @@ export const useExperimentStore = create()( description: "", mu_init: 0, sigma_init: 1, - } as NewCMABArm, + } as NewArm, { name: "", description: "", mu_init: 0, sigma_init: 1, - } as NewCMABArm, + } as NewArm, ], contexts: [ { @@ -239,7 +226,7 @@ export const useExperimentStore = create()( value_type: "binary", } as NewContext, ], - } as CMABExperimentState; + } as NewExperimentState; } else if (newMethodType == "bayes_ab") { newState = { ...experimentState, @@ -254,16 +241,16 @@ export const useExperimentStore = create()( mu_init: 0, sigma_init: 1, is_treatment_arm: true, - } as NewBayesianABArm, + } as NewArm, { name: "", description: "", mu_init: 0, sigma_init: 1, is_treatment_arm: false, - } as NewBayesianABArm, + } as NewArm, ], - } as BayesianABState; + } as NewExperimentState; } else { throw new Error("Invalid method type"); } @@ -280,10 +267,10 @@ export const useExperimentStore = create()( return { experimentState }; // Create new state based on prior type - let newState: ExperimentState; + let newState: NewExperimentState; const baseArm = { name: "", description: "" }; - if (experimentState.methodType === "mab") { + if (experimentState.exp_type === "mab") { if (newPriorType === "beta") { newState = { ...experimentState, @@ -292,8 +279,8 @@ export const useExperimentStore = create()( ...baseArm, alpha_init: 1, beta_init: 1, - })) as NewMABArmBeta[], - } as MABExperimentStateBeta; + })) as NewArm[], + } as 
NewExperimentState; } else { newState = { ...experimentState, @@ -302,30 +289,30 @@ export const useExperimentStore = create()( ...baseArm, mu_init: 0, sigma_init: 1, - })) as NewMABArmNormal[], - } as MABExperimentStateNormal; + })) as NewArm[], + } as NewExperimentState; } - } else if (experimentState.methodType === "cmab") { + } else if (experimentState.exp_type === "cmab") { newState = { ...experimentState, - priorType: "normal", + prior_type: "normal", arms: experimentState.arms.map(() => ({ ...baseArm, mu_init: 0, sigma_init: 1, - })) as NewCMABArm[], - contexts: (experimentState as CMABExperimentState).contexts, - } as CMABExperimentState; - } else if (experimentState.methodType === "bayes_ab"){ + })) as NewArm[], + contexts: (experimentState as NewExperimentState).contexts, + } as NewExperimentState; + } else if (experimentState.exp_type === "bayes_ab"){ newState = { ...experimentState, - priorType: newPriorType, + prior_type: newPriorType, arms: experimentState.arms.map(() => ({ ...baseArm, mu_init: 0, sigma_init: 1, - })) as NewBayesianABArm[], - } as BayesianABState; + })) as NewArm[], + } as NewExperimentState; } else { throw new Error("Invalid method type"); } @@ -344,97 +331,36 @@ export const useExperimentStore = create()( // ------------ Arms updates ------------ updateArms: ( - newArms: NewMABArmBeta[] | NewMABArmNormal[] | NewCMABArm[] | NewBayesianABArm[] + newArms: NewArm[] ) => set((state) => { - const { experimentState } = state; - if (isMABExperimentStateBeta(experimentState)) { - const validatedArms = newArms as NewMABArmBeta[]; - const updatedState: MABExperimentStateBeta = { - ...experimentState, - arms: validatedArms, - }; - return { experimentState: updatedState }; - } else if (isMABExperimentStateNormal(experimentState)) { - const validatedArms = newArms as NewMABArmNormal[]; - const updatedState: MABExperimentStateNormal = { - ...experimentState, - arms: validatedArms, - }; - return { experimentState: updatedState }; - } else if 
(isCMABExperimentState(experimentState)) { - const validatedArms = newArms as NewCMABArm[]; - const updatedState: CMABExperimentState = { - ...experimentState, - arms: validatedArms, - }; - return { experimentState: updatedState }; - } else if (isBayesianABState(experimentState)) { - const validatedArms = newArms as NewBayesianABArm[]; - const updatedState: BayesianABState = { - ...experimentState, + const validatedArms = newArms as NewArm[]; + return { + experimentState: { + ...state.experimentState, arms: validatedArms, - }; - return { experimentState: updatedState }; - } else { - throw new Error("Invalid method type") - } + }, + }; }), updateArm: ( index: number, - armUpdate: Partial< - NewMABArmBeta | NewMABArmNormal | NewCMABArm | NewBayesianABArm - > + armUpdate: Partial ) => set((state) => { - if (isMABExperimentStateBeta(state.experimentState)) { - const newArms = JSON.parse( - JSON.stringify(state.experimentState.arms) - ) as NewMABArmBeta[]; - newArms[index] = { - ...newArms[index], - ...(armUpdate as Partial), - }; - return { - experimentState: { ...state.experimentState, arms: newArms }, - }; - } else if (isMABExperimentStateNormal(state.experimentState)) { - const newArms = JSON.parse( - JSON.stringify(state.experimentState.arms) - ) as NewMABArmNormal[]; - newArms[index] = { - ...newArms[index], - ...(armUpdate as Partial), - }; - return { - experimentState: { ...state.experimentState, arms: newArms }, - }; - } else if (isCMABExperimentState(state.experimentState)) { - const newArms = JSON.parse( - JSON.stringify(state.experimentState.arms) - ) as NewCMABArm[]; - newArms[index] = { - ...newArms[index], - ...(armUpdate as Partial), - }; - return { - experimentState: { ...state.experimentState, arms: newArms }, - }; - } else if (isBayesianABState(state.experimentState)) { - const newArms = JSON.parse( - JSON.stringify(state.experimentState.arms) - ) as NewBayesianABArm[]; - newArms[index] = { - ...newArms[index], - ...(armUpdate as Partial), - }; - 
return { - experimentState: { ...state.experimentState, arms: newArms }, - }; - } else { - throw new Error("Invalid method type"); - } + const newArms = JSON.parse( + JSON.stringify(state.experimentState.arms) + ) as NewArm[]; + newArms[index] = { + ...newArms[index], + ...(armUpdate as Partial), + }; + return { + experimentState: { + ...state.experimentState, + arms: newArms, + }, + }; }), addArm: () => @@ -446,43 +372,29 @@ export const useExperimentStore = create()( description: "", alpha_init: 1, beta_init: 1, - } as NewMABArmBeta; - return { - experimentState: { - ...experimentState, - arms: [...experimentState.arms, newArm], - }, - }; - } else if (isMABExperimentStateNormal(experimentState)) { - const newArm = { - name: "", - description: "", - mu_init: 0, - sigma_init: 1, - } as NewMABArmNormal; + } as NewArm; return { experimentState: { ...experimentState, arms: [...experimentState.arms, newArm], }, }; - } else if (isCMABExperimentState(experimentState)) { + } else if (isBayesianABState(experimentState)) { + throw new Error("Adding arms for Bayesian A/B experiments is not currently supported"); + } else { const newArm = { name: "", description: "", mu_init: 0, sigma_init: 1, - } as NewCMABArm; + } as NewArm; return { experimentState: { ...experimentState, arms: [...experimentState.arms, newArm], }, }; - } else if (isBayesianABState(experimentState)){ - throw new Error("Adding arms for Bayesian A/B experiments is not currently supported"); } - return { experimentState }; // Return original state for any other case }), removeArm: (index: number) => @@ -492,37 +404,12 @@ export const useExperimentStore = create()( const newArms = [...experimentState.arms]; newArms.splice(index, 1); - if (isMABExperimentStateBeta(experimentState)) { - return { - experimentState: { - ...experimentState, - arms: newArms as NewMABArmBeta[], - }, - }; - } else if (isMABExperimentStateNormal(experimentState)) { - return { - experimentState: { - ...experimentState, - arms: newArms as 
NewMABArmNormal[], - }, - }; - } else if (isCMABExperimentState(experimentState)) { - return { - experimentState: { - ...experimentState, - arms: newArms as NewCMABArm[], - }, - }; - } else if (isBayesianABState(experimentState)) { - return { - experimentState: { - ...experimentState, - arms: newArms as NewBayesianABArm[], - }, - }; - } else { - throw new Error("Invalid method type"); - } + return { + experimentState: { + ...experimentState, + arms: newArms as NewArm[], + }, + }; }), // ------------ Context updates ------------ @@ -536,7 +423,7 @@ export const useExperimentStore = create()( experimentState: { ...experimentState, contexts: newContexts, - } as CMABExperimentState, + } as NewExperimentState, }; }), @@ -546,7 +433,7 @@ export const useExperimentStore = create()( if (!isCMABExperimentState(experimentState)) return { experimentState }; - const newContexts = [...experimentState.contexts]; + const newContexts = [...(experimentState.contexts || [])]; newContexts[index] = { ...newContexts[index], ...contextUpdate }; return { @@ -572,7 +459,7 @@ export const useExperimentStore = create()( return { experimentState: { ...experimentState, - contexts: [...experimentState.contexts, newContext], + contexts: [...(experimentState.contexts || []), newContext], }, }; }), @@ -582,6 +469,7 @@ export const useExperimentStore = create()( const { experimentState } = state; if ( !isCMABExperimentState(experimentState) || + !experimentState.contexts || experimentState.contexts.length <= 1 ) return { experimentState }; diff --git a/frontend/src/app/(protected)/experiments/types.ts b/frontend/src/app/(protected)/experiments/types.ts index 3a3c0a2..d163323 100644 --- a/frontend/src/app/(protected)/experiments/types.ts +++ b/frontend/src/app/(protected)/experiments/types.ts @@ -11,8 +11,8 @@ interface BetaParams { interface GaussianParams { name: string; - mu: number; - sigma: number; + mu: Array; + covariance: Array; } interface StepComponentProps { @@ -46,7 +46,7 @@ interface 
Context extends NewContext { interface ExperimentStateBase { name: string; description: string; - methodType: MethodType; + exp_type: MethodType; prior_type: PriorType; reward_type: RewardType; sticky_assignment: boolean; @@ -55,148 +55,55 @@ interface ExperimentStateBase { auto_fail_unit: "days" | "hours"; } -interface ArmBase { +interface NewArm { name: string; description: string; + mu_init?: number; + sigma_init?: number; + alpha_init?: number; + beta_init?: number; + is_treatment_arm?: boolean; } interface StepValidation { isValid: boolean; errors: Record | Record[]; } -// ----- Bayesian AB -interface NewBayesianABArm extends ArmBase { - mu_init: number; - sigma_init: number; - is_treatment_arm: boolean; -} - -interface BayesianABArm extends NewBayesianABArm { +interface Arm extends NewArm { arm_id: number; - mu: number; - sigma: number; + alpha?: number; + beta?: number; + mu?: number[]; + covariance?: number[][]; } -interface BayesianABState extends ExperimentStateBase { - methodType: "bayes_ab"; - arms: NewBayesianABArm[]; +interface NewExperimentState extends ExperimentStateBase { + arms: NewArm[]; notifications: Notifications; + contexts?: NewContext[]; } -interface BayesianAB extends BayesianABState { - experiment_id: number; - is_active: boolean; - arms: BayesianABArm[]; -} - -// ----- MAB - -interface NewMABArmBeta extends ArmBase { - alpha_init: number; - beta_init: number; -} - -interface NewMABArmNormal extends ArmBase { - mu_init: number; - sigma_init: number; -} - -interface MABArmBeta extends NewMABArmBeta { - arm_id: number; - alpha: number; - beta: number; -} - -interface MABArmNormal extends NewMABArmNormal { - arm_id: number; - mu: number; - sigma: number; -} - -interface MABExperimentStateNormal extends ExperimentStateBase { - methodType: "mab"; - arms: NewMABArmNormal[]; - notifications: Notifications; -} - -interface MABExperimentStateBeta extends ExperimentStateBase { - methodType: "mab"; - arms: NewMABArmBeta[]; - notifications: 
Notifications; -} - -interface MABNormal extends MABExperimentStateNormal { +interface ExperimentState extends NewExperimentState { experiment_id: number; is_active: boolean; last_trial_datetime_utc: string; - arms: MABArmNormal[]; -} - -interface MABBeta extends MABExperimentStateBeta { - experiment_id: number; - is_active: boolean; - last_trial_datetime_utc: string; - arms: MABArmBeta[]; -} - -// ----- CMAB - -interface NewCMABArm extends ArmBase { - mu_init: number; - sigma_init: number; -} - -interface CMABArm extends NewCMABArm { - arm_id: number; - mu: number[]; - sigma: number[]; -} - -interface CMABExperimentState extends ExperimentStateBase { - methodType: "cmab"; - arms: NewCMABArm[]; - contexts: NewContext[]; - notifications: Notifications; -} - -interface CMAB extends CMABExperimentState { - experiment_id: number; - is_active: boolean; - arms: CMABArm[]; + arms: Arm[]; + contexts?: Context[]; } -type ExperimentState = - | MABExperimentStateNormal - | MABExperimentStateBeta - | CMABExperimentState - | BayesianABState; export type { - BayesianAB, - BayesianABArm, - BayesianABState, - ArmBase, + Arm, BetaParams, - CMAB, - CMABArm, - CMABExperimentState, Context, ExperimentState, ExperimentStateBase, GaussianParams, - MABBeta, - MABNormal, - MABArmBeta, - MABArmNormal, - MABExperimentStateBeta, - MABExperimentStateNormal, MethodType, - NewBayesianABArm, - NewCMABArm, + NewArm, + NewExperimentState, NewContext, - NewMABArmBeta, - NewMABArmNormal, Notifications, PriorType, RewardType, From c8bb88d58eed6bae904e4201f8ed97f6f5008712 Mon Sep 17 00:00:00 2001 From: poornimaramesh Date: Mon, 9 Jun 2025 20:02:40 +0300 Subject: [PATCH 62/74] fix experiment viz --- .../{MABArmsProgress.tsx => ArmsProgress.tsx} | 10 +++--- .../{MABChart.tsx => ExperimentChart.tsx} | 16 +++++----- .../experiments/[experimentId]/page.tsx | 32 ++++++++++++------- .../experiments/[experimentId]/types.ts | 27 ++++++++-------- .../src/app/(protected)/experiments/api.ts | 2 +- 
.../src/app/(protected)/experiments/types.ts | 7 ++-- 6 files changed, 54 insertions(+), 40 deletions(-) rename frontend/src/app/(protected)/experiments/[experimentId]/components/{MABArmsProgress.tsx => ArmsProgress.tsx} (88%) rename frontend/src/app/(protected)/experiments/[experimentId]/components/{MABChart.tsx => ExperimentChart.tsx} (85%) diff --git a/frontend/src/app/(protected)/experiments/[experimentId]/components/MABArmsProgress.tsx b/frontend/src/app/(protected)/experiments/[experimentId]/components/ArmsProgress.tsx similarity index 88% rename from frontend/src/app/(protected)/experiments/[experimentId]/components/MABArmsProgress.tsx rename to frontend/src/app/(protected)/experiments/[experimentId]/components/ArmsProgress.tsx index 6cb7221..7338131 100644 --- a/frontend/src/app/(protected)/experiments/[experimentId]/components/MABArmsProgress.tsx +++ b/frontend/src/app/(protected)/experiments/[experimentId]/components/ArmsProgress.tsx @@ -22,7 +22,7 @@ export default function MABArmsProgress({ }: { armsData: MABArmDetails[]; }) { - const maxMu = Math.max(...armsData.map((arm) => arm.mu)); + const maxMu = Math.max(...armsData.map((arm) => arm.mu ? arm.mu[0] : 0)); return ( @@ -56,18 +56,18 @@ export default function MABArmsProgress({ {arm.n_outcomes}
    - {arm.beta + {(arm.alpha && arm.beta) ? `${((arm.alpha * 100) / (arm.alpha + arm.beta)).toFixed( 1 )}%` - : `${arm.mu.toFixed(1)}`} + : arm.mu ? (`${arm.mu[0].toFixed(1)}`) : (``)}
    diff --git a/frontend/src/app/(protected)/experiments/[experimentId]/components/MABChart.tsx b/frontend/src/app/(protected)/experiments/[experimentId]/components/ExperimentChart.tsx similarity index 85% rename from frontend/src/app/(protected)/experiments/[experimentId]/components/MABChart.tsx rename to frontend/src/app/(protected)/experiments/[experimentId]/components/ExperimentChart.tsx index 6c58054..8b88d85 100644 --- a/frontend/src/app/(protected)/experiments/[experimentId]/components/MABChart.tsx +++ b/frontend/src/app/(protected)/experiments/[experimentId]/components/ExperimentChart.tsx @@ -24,27 +24,27 @@ export default function MABChart({ const priorBetaData = experimentData.arms.map((arm) => ({ name: arm.name, - alpha: arm.alpha_init, - beta: arm.beta_init, + alpha: arm.alpha_init ? arm.alpha_init : 1, + beta: arm.beta_init ? arm.beta_init : 1, })); const posteriorBetaData = experimentData.arms.map((arm) => ({ name: arm.name, - alpha: arm.alpha, - beta: arm.beta, + alpha: arm.alpha ? arm.alpha : 1, + beta: arm.beta ? arm.beta : 1, })); const priorGaussianData = experimentData.arms.map((arm) => ({ name: arm.name, - mu: arm.mu_init, - sigma: arm.sigma_init, + mu: [arm.mu_init ? arm.mu_init : 0], + covariance: [[arm.sigma_init ? arm.sigma_init : 1]], })); const posteriorGaussianData = experimentData.arms.map((arm) => ({ name: arm.name, - mu: arm.mu, - sigma: arm.sigma, + mu: arm.mu ? arm.mu : [0], + covariance: arm.covariance ? 
arm.covariance : [[1]], })); return ( diff --git a/frontend/src/app/(protected)/experiments/[experimentId]/page.tsx b/frontend/src/app/(protected)/experiments/[experimentId]/page.tsx index 809c250..68e2a09 100644 --- a/frontend/src/app/(protected)/experiments/[experimentId]/page.tsx +++ b/frontend/src/app/(protected)/experiments/[experimentId]/page.tsx @@ -2,12 +2,12 @@ import { useState, useEffect } from "react"; import { Badge } from "@/components/ui/badge"; -import MABChart from "./components/MABChart"; -import MABArmsProgress from "./components/MABArmsProgress"; +import MABChart from "./components/ExperimentChart"; +import MABArmsProgress from "./components/ArmsProgress"; import NotificationDetails from "./components/Notifications"; import ExtraInfo from "./components/ExtraInfo"; -import { getMABExperimentById } from "../api"; +import { getExperimentById } from "../api"; import { useParams } from "next/navigation"; import { useAuth } from "@/utils/auth"; import { @@ -20,17 +20,17 @@ import { } from "@/components/ui/breadcrumb"; import { - MABExperimentDetails, - MABArmDetails, + SingleExperimentDetails, + ArmDetails, Notification, ExtraInfo as ExtraInfoType, } from "./types"; export default function ExperimentDetails() { const { experimentId } = useParams(); - const [armsDetails, setArmsDetails] = useState([]); + const [armsDetails, setArmsDetails] = useState([]); const [experimentDetails, setExperimentDetails] = - useState(null); + useState(null); const [notificationData, setNotificationData] = useState([]); const [extraInfo, setExtraInfo] = useState(null); @@ -40,10 +40,13 @@ export default function ExperimentDetails() { useEffect(() => { if (!token) return; - getMABExperimentById(token, Number(experimentId)).then((data) => { + getExperimentById(token, Number(experimentId)).then((data) => { setArmsDetails(data.arms); - setExperimentDetails(data); - setNotificationData(data.notifications); + setExperimentDetails({ + ...data, + notifications: 
Array.isArray(data.notifications) ? data.notifications : [] + }); + setNotificationData(Array.isArray(data.notifications) ? data.notifications : []); setExtraInfo({ dateCreated: data.created_datetime_utc, lastTrialDate: data.last_trial_datetime_utc, @@ -54,6 +57,7 @@ export default function ExperimentDetails() { }, [experimentId, token]); return ( + experimentDetails?.exp_type == "mab" ? (
    @@ -113,5 +117,11 @@ export default function ExperimentDetails() {
    - ); +) : ( +
    +

    + We're working on visualizing details for {experimentDetails?.exp_type} experiments. Stay tuned! +

    +
    + )); } diff --git a/frontend/src/app/(protected)/experiments/[experimentId]/types.ts b/frontend/src/app/(protected)/experiments/[experimentId]/types.ts index 71c66fd..4dd65df 100644 --- a/frontend/src/app/(protected)/experiments/[experimentId]/types.ts +++ b/frontend/src/app/(protected)/experiments/[experimentId]/types.ts @@ -1,28 +1,29 @@ -interface MABExperimentDetails { +interface SingleExperimentDetails { name: string; description: string; - reward: string; + reward_type: string; prior_type: string; + exp_type: string; is_active: boolean; experiment_id: number; created_datetime_utc: string; last_trial_datetime_utc: string; n_trials: number; - arms: MABArmDetails[]; + arms: ArmDetails[]; notifications: Notification[]; } -interface MABArmDetails { +interface ArmDetails { name: string; description: string; - alpha_init: number; - beta_init: number; - mu_init: number; - sigma_init: number; - alpha: number; - beta: number; - mu: number; - sigma: number; + alpha_init?: number; + beta_init?: number; + mu_init?: number; + sigma_init?: number; + alpha?: number; + beta?: number; + mu?: number[]; + covariance?: number[][]; arm_id: number; n_outcomes: number; } @@ -41,4 +42,4 @@ interface ExtraInfo { nTrials: number; } -export type { MABExperimentDetails, MABArmDetails, Notification, ExtraInfo }; +export type { SingleExperimentDetails, ArmDetails, Notification, ExtraInfo }; diff --git a/frontend/src/app/(protected)/experiments/api.ts b/frontend/src/app/(protected)/experiments/api.ts index 32d8fd6..c877cff 100644 --- a/frontend/src/app/(protected)/experiments/api.ts +++ b/frontend/src/app/(protected)/experiments/api.ts @@ -44,7 +44,7 @@ const getExperimentsByType = async (token: string | null, exp_type: string) => { const getExperimentById = async (token: string | null, id: number) => { try { - const response = await api.get(`/experiment/${id}/`, { + const response = await api.get(`/experiment/id/${id}`, { headers: { Authorization: `Bearer ${token}`, }, diff --git 
a/frontend/src/app/(protected)/experiments/types.ts b/frontend/src/app/(protected)/experiments/types.ts index d163323..ba884c5 100644 --- a/frontend/src/app/(protected)/experiments/types.ts +++ b/frontend/src/app/(protected)/experiments/types.ts @@ -11,8 +11,8 @@ interface BetaParams { interface GaussianParams { name: string; - mu: Array; - covariance: Array; + mu: number[]; + covariance: number[][]; } interface StepComponentProps { @@ -76,6 +76,7 @@ interface Arm extends NewArm { beta?: number; mu?: number[]; covariance?: number[][]; + n_outcomes: number; } interface NewExperimentState extends ExperimentStateBase { @@ -88,6 +89,8 @@ interface ExperimentState extends NewExperimentState { experiment_id: number; is_active: boolean; last_trial_datetime_utc: string; + created_datetime_utc: string; + n_trials: number; arms: Arm[]; contexts?: Context[]; } From 4e0078e030e66e89db581ddb080cf497d4ae8ef7 Mon Sep 17 00:00:00 2001 From: poornimaramesh Date: Mon, 9 Jun 2025 21:35:39 +0300 Subject: [PATCH 63/74] debug prior-reward config --- .../components/ArmsProgress.tsx | 4 +- .../[experimentId]/components/Charts.tsx | 6 +- .../components/ExperimentChart.tsx | 4 +- .../add/components/addPriorReward.tsx | 158 ++++++++++++++++++ .../experiments/add/components/basicInfo.tsx | 14 +- .../app/(protected)/experiments/add/page.tsx | 6 +- .../experiments/store/useExperimentStore.ts | 9 +- 7 files changed, 178 insertions(+), 23 deletions(-) create mode 100644 frontend/src/app/(protected)/experiments/add/components/addPriorReward.tsx diff --git a/frontend/src/app/(protected)/experiments/[experimentId]/components/ArmsProgress.tsx b/frontend/src/app/(protected)/experiments/[experimentId]/components/ArmsProgress.tsx index 7338131..c9017d9 100644 --- a/frontend/src/app/(protected)/experiments/[experimentId]/components/ArmsProgress.tsx +++ b/frontend/src/app/(protected)/experiments/[experimentId]/components/ArmsProgress.tsx @@ -15,12 +15,12 @@ import { Info } from "lucide-react"; import { 
Progress } from "@/components/ui/progress"; import { Badge } from "@/components/ui/badge"; -import { MABArmDetails } from "../types"; +import { ArmDetails } from "../types"; export default function MABArmsProgress({ armsData, }: { - armsData: MABArmDetails[]; + armsData: ArmDetails[]; }) { const maxMu = Math.max(...armsData.map((arm) => arm.mu ? arm.mu[0] : 0)); return ( diff --git a/frontend/src/app/(protected)/experiments/[experimentId]/components/Charts.tsx b/frontend/src/app/(protected)/experiments/[experimentId]/components/Charts.tsx index c8697a0..1eacf52 100644 --- a/frontend/src/app/(protected)/experiments/[experimentId]/components/Charts.tsx +++ b/frontend/src/app/(protected)/experiments/[experimentId]/components/Charts.tsx @@ -160,15 +160,15 @@ const NormalLineChart = ({ const data = x.map((xVal) => { const point: { x: number; [key: string]: number } = { x: xVal }; - const posteriorPDFs = posteriors.map(({ mu, sigma }) => - normalPDF(xVal, mu, sigma) + const posteriorPDFs = posteriors.map(({ mu, covariance }) => + normalPDF(xVal, mu[0], covariance[0][0]) ); posteriors.forEach(({ name }, i) => { point[`Posterior - ${i}_${name}`] = posteriorPDFs[i]; }); - const priorPDFs = priors.map(({ mu, sigma }) => normalPDF(xVal, mu, sigma)); + const priorPDFs = priors.map(({ mu, covariance }) => normalPDF(xVal, mu[0], covariance[0][0])); priors.forEach(({ name }, i) => { point[`Prior - ${i}_${name}`] = priorPDFs[i]; diff --git a/frontend/src/app/(protected)/experiments/[experimentId]/components/ExperimentChart.tsx b/frontend/src/app/(protected)/experiments/[experimentId]/components/ExperimentChart.tsx index 8b88d85..de8e680 100644 --- a/frontend/src/app/(protected)/experiments/[experimentId]/components/ExperimentChart.tsx +++ b/frontend/src/app/(protected)/experiments/[experimentId]/components/ExperimentChart.tsx @@ -8,13 +8,13 @@ import { CardTitle, } from "@/components/ui/card"; import { Switch } from "@/components/ui/switch"; -import { MABExperimentDetails } from 
"../types"; +import { SingleExperimentDetails } from "../types"; import { BetaLineChart, NormalLineChart } from "./Charts"; export default function MABChart({ experimentData, }: { - experimentData: MABExperimentDetails | null; + experimentData: SingleExperimentDetails | null; }) { const [showPriors, setShowPriors] = useState(false); diff --git a/frontend/src/app/(protected)/experiments/add/components/addPriorReward.tsx b/frontend/src/app/(protected)/experiments/add/components/addPriorReward.tsx new file mode 100644 index 0000000..bc5257d --- /dev/null +++ b/frontend/src/app/(protected)/experiments/add/components/addPriorReward.tsx @@ -0,0 +1,158 @@ +import { useExperimentStore, isBayesianABState, isCMABExperimentState } from "../../store/useExperimentStore"; +import { useCallback, useState, useEffect } from "react"; +import { RadioGroup, RadioGroupItem } from "@/components/ui/radio-group"; +import { Label } from "@/components/ui/label"; +import type { PriorType, RewardType, StepComponentProps } from "../../types"; +import { DividerWithTitle } from "@/components/Dividers"; + +export default function PriorRewardSelection({ + onValidate, +}: StepComponentProps) { + const { experimentState, updatePriorType, updateRewardType } = + useExperimentStore(); + const [errors, setErrors] = useState({ + prior_type: "", + reward_type: "", + }); + + const validateForm = useCallback(() => { + let isValid = true; + const newErrors = { + prior_type: "", + reward_type: "", + }; + + if (!experimentState.prior_type) { + newErrors.prior_type = "Please select a prior type"; + isValid = false; + } + + if (!experimentState.reward_type) { + newErrors.reward_type = "Please select a reward type"; + isValid = false; + } + + if ( + experimentState.prior_type === "beta" && + experimentState.reward_type === "real-valued" + ) { + newErrors.reward_type = + "Beta prior is not compatible with real-valued reward"; + isValid = false; + } + + if ((isBayesianABState(experimentState) || 
isCMABExperimentState(experimentState)) && experimentState.prior_type === "beta") { + newErrors.prior_type = "Beta prior is not compatible with Bayesian AB or CMAB experiments"; + isValid = false; + } + + return { isValid, newErrors }; + }, [experimentState.prior_type, experimentState.reward_type]); + + useEffect(() => { + const { isValid, newErrors } = validateForm(); + if (JSON.stringify(newErrors) !== JSON.stringify(errors)) { + setErrors(newErrors); + onValidate({ isValid, errors: newErrors }); + } + }, [validateForm, onValidate, errors]); + + useEffect(() => { + const { isValid, newErrors } = validateForm(); + setErrors(newErrors); + onValidate({ isValid, errors: newErrors }); + }, []); + + return ( +
    +
    +

    + Configure Experiment Parameters +

    +
    +
    + +
    + + updatePriorType(value as PriorType)} + className="space-y-4" + > +
    + +
    + +

    + Gaussian distribution; best for real-valued outcomes. +

    +
    +
    + +
    + +
    + +

    + Beta distribution; best for binary outcomes. +

    +
    +
    +
    + {errors.prior_type ? ( +

    {errors.prior_type}

    + ) : ( +

     

    + )} +
    + + +
    + + updateRewardType(value as RewardType)} + className="space-y-4" + > +
    + +
    + +

    + E.g. how long someone engaged with your app, how long did + onboarding take, etc. +

    +
    +
    + +
    + +
    + +

    + E.g. whether a user clicked on a button, whether a user + converted, etc. +

    +
    +
    +
    + {errors.reward_type ? ( +

    + {errors.reward_type} +

    + ) : ( +

     

    + )} +
    +
    +
    + ); +} diff --git a/frontend/src/app/(protected)/experiments/add/components/basicInfo.tsx b/frontend/src/app/(protected)/experiments/add/components/basicInfo.tsx index 93bd293..b7a5b25 100644 --- a/frontend/src/app/(protected)/experiments/add/components/basicInfo.tsx +++ b/frontend/src/app/(protected)/experiments/add/components/basicInfo.tsx @@ -68,12 +68,12 @@ export default function AddBasicInfo({ const [errors, setErrors] = useState({ name: "", description: "", - methodType: "", + exp_type: "", }); const validateForm = useCallback(() => { let isValid = true; - const newErrors = { name: "", description: "", methodType: "" }; + const newErrors = { name: "", description: "", exp_type: "" }; if (!experimentState.name.trim()) { newErrors.name = "Experiment name is required"; @@ -85,8 +85,8 @@ export default function AddBasicInfo({ isValid = false; } - if (!experimentState.methodType) { - newErrors.methodType = "Please select an experiment type"; + if (!experimentState.exp_type) { + newErrors.exp_type = "Please select an experiment type"; isValid = false; } @@ -164,15 +164,15 @@ export default function AddBasicInfo({ description={methodInfo[method].description} infoTitle={methodInfo[method].infoTitle} infoDescription={methodInfo[method].infoDescription} - selected={experimentState.methodType === method} + selected={experimentState.exp_type === method} disabled={methodInfo[method].disabled} onClick={() => updateMethodType(method as keyof Methods)} /> ) )}
    - {errors.methodType ? ( -

    {errors.methodType}

    + {errors.exp_type ? ( +

    {errors.exp_type}

    ) : (

     

    )} diff --git a/frontend/src/app/(protected)/experiments/add/page.tsx b/frontend/src/app/(protected)/experiments/add/page.tsx index 4e6fd0b..6565ab4 100644 --- a/frontend/src/app/(protected)/experiments/add/page.tsx +++ b/frontend/src/app/(protected)/experiments/add/page.tsx @@ -29,7 +29,7 @@ export default function NewExperiment() { const { token } = useAuth(); const router = useRouter(); - const [steps, setSteps] = useState(AllSteps[experimentState.methodType]); + const [steps, setSteps] = useState(AllSteps(experimentState.exp_type)); const [isSubmitting, setIsSubmitting] = useState(false); const { toast } = useToast(); @@ -39,9 +39,9 @@ export default function NewExperiment() { }, []); useEffect(() => { - setSteps(AllSteps[experimentState.methodType]); + setSteps(AllSteps(experimentState.exp_type)); setCurrentStep(0); - }, [experimentState.methodType]); + }, [experimentState.exp_type]); const nextStep = useCallback(() => { const currentValidation = stepValidations[currentStep]; diff --git a/frontend/src/app/(protected)/experiments/store/useExperimentStore.ts b/frontend/src/app/(protected)/experiments/store/useExperimentStore.ts index 3c5a02b..13a3747 100644 --- a/frontend/src/app/(protected)/experiments/store/useExperimentStore.ts +++ b/frontend/src/app/(protected)/experiments/store/useExperimentStore.ts @@ -169,9 +169,6 @@ export const useExperimentStore = create()( updateMethodType: (newMethodType: MethodType) => set((state) => { const { experimentState } = state; - if (newMethodType === experimentState.exp_type) - return { experimentState }; - const { reward_type, notifications } = experimentState; let newState: NewExperimentState; @@ -201,7 +198,7 @@ export const useExperimentStore = create()( } else if (newMethodType == "cmab") { newState = { ...experimentState, - methodType: newMethodType, + exp_type: newMethodType, prior_type: "normal", reward_type, notifications, @@ -230,7 +227,7 @@ export const useExperimentStore = create()( } else if (newMethodType 
== "bayes_ab") { newState = { ...experimentState, - methodType: newMethodType, + exp_type: newMethodType, prior_type: "normal", reward_type, notifications, @@ -295,7 +292,7 @@ export const useExperimentStore = create()( } else if (experimentState.exp_type === "cmab") { newState = { ...experimentState, - prior_type: "normal", + prior_type: newPriorType, arms: experimentState.arms.map(() => ({ ...baseArm, mu_init: 0, From 3d64bf7995fa5278155ed753d402e6208dc2538a Mon Sep 17 00:00:00 2001 From: poornimaramesh Date: Tue, 10 Jun 2025 10:04:19 +0300 Subject: [PATCH 64/74] add context page --- .../addCMABContext.tsx => addContext.tsx} | 41 +++++++++++-------- 1 file changed, 23 insertions(+), 18 deletions(-) rename frontend/src/app/(protected)/experiments/add/components/{cmabs/addCMABContext.tsx => addContext.tsx} (90%) diff --git a/frontend/src/app/(protected)/experiments/add/components/cmabs/addCMABContext.tsx b/frontend/src/app/(protected)/experiments/add/components/addContext.tsx similarity index 90% rename from frontend/src/app/(protected)/experiments/add/components/cmabs/addCMABContext.tsx rename to frontend/src/app/(protected)/experiments/add/components/addContext.tsx index 79cf234..8a4b4a4 100644 --- a/frontend/src/app/(protected)/experiments/add/components/cmabs/addCMABContext.tsx +++ b/frontend/src/app/(protected)/experiments/add/components/addContext.tsx @@ -3,17 +3,17 @@ import { Input } from "@/components/ui/input"; import { Textarea } from "@/components/ui/textarea"; import { Label } from "@/components/ui/label"; import { RadioGroup, RadioGroupItem } from "@/components/ui/radio-group"; -import { useExperimentStore } from "../../../store/useExperimentStore"; +import { useExperimentStore } from "../../store/useExperimentStore"; import type { - CMABExperimentState, + ExperimentState, StepComponentProps, ContextType, -} from "../../../types"; +} from "../../types"; import { Plus, Trash } from "lucide-react"; import { DividerWithTitle } from 
"@/components/Dividers"; import { useCallback, useEffect, useState } from "react"; -export default function AddCMABContext({ onValidate }: StepComponentProps) { +export default function AddContext({ onValidate }: StepComponentProps) { const { experimentState, updateContext, addContext, removeContext } = useExperimentStore(); @@ -23,16 +23,17 @@ export default function AddCMABContext({ onValidate }: StepComponentProps) { const validateForm = useCallback(() => { let isValid = true; - const newErrors = (experimentState as CMABExperimentState).contexts.map( - () => ({ + let newErrors = [{ name: "", description: "", value_type: "" }]; + const contexts = (experimentState as ExperimentState).contexts; + + if (contexts) { + newErrors = contexts.map(() => ({ name: "", description: "", value_type: "", - }) - ); + })); - (experimentState as CMABExperimentState).contexts.forEach( - (context, index) => { + contexts.forEach((context, index) => { if (!context.name.trim()) { newErrors[index].name = "Context name is required"; isValid = false; @@ -47,8 +48,8 @@ export default function AddCMABContext({ onValidate }: StepComponentProps) { newErrors[index].value_type = "Context value type is required"; isValid = false; } - } - ); + }); + } return { isValid, newErrors }; }, [experimentState]); @@ -68,11 +69,11 @@ export default function AddCMABContext({ onValidate }: StepComponentProps) { } }, [validateForm, onValidate, errors]); - return ( + return experimentState.contexts ? (

    - Add CMAB Contexts + Add Contexts

    - {(experimentState as CMABExperimentState).contexts.map( + {experimentState.contexts.map( (context, index) => (
    @@ -233,5 +234,9 @@ export default function AddCMABContext({ onValidate }: StepComponentProps) { )}
    + ) : ( +
    + No contexts available. Please add a context to proceed. +
    ); } From 2365592947b0bd557d006db16777783310c17e73 Mon Sep 17 00:00:00 2001 From: poornimaramesh Date: Tue, 10 Jun 2025 10:25:45 +0300 Subject: [PATCH 65/74] add input arms page --- .../experiments/add/components/addArms.tsx | 388 ++++++++++++++++++ .../add/components/addExperimentSteps.tsx | 120 +++--- 2 files changed, 460 insertions(+), 48 deletions(-) create mode 100644 frontend/src/app/(protected)/experiments/add/components/addArms.tsx diff --git a/frontend/src/app/(protected)/experiments/add/components/addArms.tsx b/frontend/src/app/(protected)/experiments/add/components/addArms.tsx new file mode 100644 index 0000000..48c8e3e --- /dev/null +++ b/frontend/src/app/(protected)/experiments/add/components/addArms.tsx @@ -0,0 +1,388 @@ +import { + useExperimentStore, + isMABExperimentStateBeta, + isBayesianABState +} from "../../store/useExperimentStore"; +import { Button } from "@/components/ui/button"; +import { Input } from "@/components/ui/input"; +import { Textarea } from "@/components/ui/textarea"; +import { Label } from "@/components/ui/label"; +import type { + NewArm, + StepComponentProps, +} from "../../types"; +import { Plus, Trash } from "lucide-react"; +import { DividerWithTitle } from "@/components/Dividers"; +import { useCallback, useEffect, useState, useMemo } from "react"; + +export default function AddMABArms({ onValidate }: StepComponentProps) { + const { experimentState, updateArm, addArm, removeArm } = + useExperimentStore(); + + const [inputValues, setInputValues] = useState>({}); + + const bayesABarms: Record = {1: "Treatment Arm", 2: "Control Arm"}; + + const baseArmDesc = useMemo( + () => ({ + name: "", + description: "", + }), + [] + ); + + const additionalArmErrors = useMemo( + () => + experimentState.prior_type === "beta" + ? 
{ alpha_init: "", beta_init: "" } + : { mu_init: "", sigma_init: "" }, + [experimentState] + ); + + const [errors, setErrors] = useState(() => { + return experimentState.arms.map(() => { + return { ...baseArmDesc, ...additionalArmErrors }; + }); + }); + + const validateForm = useCallback(() => { + let isValid = true; + const newErrors = experimentState.arms.map(() => ({ + ...baseArmDesc, + ...additionalArmErrors, + })); + + experimentState.arms.forEach((arm, index) => { + if (!arm.name.trim()) { + newErrors[index].name = "Arm name is required"; + isValid = false; + } + + if (!arm.description.trim()) { + newErrors[index].description = "Description is required"; + isValid = false; + } + + if (experimentState.prior_type === "beta") { + if ("alpha_init" in arm) { + if (!arm.alpha_init) { + newErrors[index].alpha_init = "Alpha prior is required"; + isValid = false; + } + if (arm.alpha_init && arm.alpha_init <= 0) { + newErrors[index].alpha_init = + "Alpha prior should be greater than 0"; + isValid = false; + } + } + + if ("beta_init" in arm) { + if (!arm.beta_init) { + newErrors[index].beta_init = "Beta prior is required"; + isValid = false; + } + + if (arm.beta_init && arm.beta_init <= 0) { + newErrors[index].beta_init = "Beta prior should be greater than 0"; + isValid = false; + } + } + } else if (experimentState.prior_type === "normal") { + if ("mu_init" in arm && typeof arm.mu_init !== "number") { + newErrors[index].mu_init = "Mean value is required"; + isValid = false; + } + + if ("sigma_init" in arm) { + if (!arm.sigma_init) { + newErrors[index].sigma_init = "Std. 
deviation is required"; + isValid = false; + } + + if (arm.sigma_init && arm.sigma_init <= 0) { + newErrors[index].sigma_init = + "Std deviation should be greater than 0"; + isValid = false; + } + } + } + }); + return { isValid, newErrors }; + }, [experimentState, baseArmDesc, additionalArmErrors]); + + useEffect(() => { + const { isValid, newErrors } = validateForm(); + if (JSON.stringify(newErrors) !== JSON.stringify(errors)) { + setErrors(newErrors); + onValidate({ + isValid, + errors: newErrors.map((error) => + Object.fromEntries( + Object.entries(error).map(([key, value]) => [key, value ?? ""]) + ) + ), + }); + } + }, [validateForm, onValidate, errors]); + + useEffect(() => { + const newInputValues: Record = {}; + + if (!isMABExperimentStateBeta(experimentState)) { + experimentState.arms.forEach((arm, index) => { + newInputValues[`${index}-mu`] = ( + (arm as NewArm).mu_init || 0 + ).toString(); + }); + } + setInputValues(newInputValues); + }, [experimentState]); + + const handleNumericChange = (index: number, value: string) => { + // Update the local input state for a smooth typing experience + setInputValues((prev) => ({ + ...prev, + [`${index}-mu`]: value, + })); + + if (value !== "" && value !== "-") { + const numValue = Number.parseFloat(value); + if (!isNaN(numValue)) { + updateArm(index, { mu_init: numValue }); + } + } + }; + + return ( +
    +
    +

    Add Experiment Arms

    +
    + + +
    +
    +
    + {experimentState.arms.map((arm, index) => ( +
    + +
    +
    +
    +
    + +
    + + updateArm(index, { name: e.target.value }) + } + /> + {errors[index]?.name ? ( +

    + {errors[index].name} +

    + ) : ( +

     

    + )} +
    +
    +
    +
    +
    + +
    +