diff --git a/backend/app/experiments/dependencies.py b/backend/app/experiments/dependencies.py index bcf15d2..6371408 100644 --- a/backend/app/experiments/dependencies.py +++ b/backend/app/experiments/dependencies.py @@ -18,7 +18,7 @@ from .schemas import ( ArmPriors, ArmResponse, - ExperimentSample, + ExperimentResponse, ExperimentsEnum, NotificationsResponse, ObservationType, @@ -30,7 +30,7 @@ async def experiments_db_to_schema( experiments_db: list[ExperimentDB], asession: AsyncSession, -) -> list[ExperimentSample]: +) -> list[ExperimentResponse]: """ Convert a list of ExperimentDB objects to a list of ExperimentResponse schemas. """ @@ -47,7 +47,7 @@ async def experiments_db_to_schema( ) ] all_experiments.append( - ExperimentSample.model_validate( + ExperimentResponse.model_validate( { **exp_dict, "notifications": [ @@ -269,7 +269,7 @@ async def update_arm_parameters( treatments: Union[list[float], None], ) -> None: """Update the arm parameters based on the reward type and outcome""" - experiment_data = ExperimentSample.model_validate(experiment.to_dict()) + experiment_data = ExperimentResponse.model_validate(experiment.to_dict()) if experiment_data.reward_type == RewardLikelihood.BERNOULLI: Outcome(rewards[0]) # Check if reward is 0 or 1 params = update_arm( diff --git a/backend/app/experiments/models.py b/backend/app/experiments/models.py index 3438b1d..f01ce94 100644 --- a/backend/app/experiments/models.py +++ b/backend/app/experiments/models.py @@ -262,6 +262,14 @@ class DrawDB(Base): context_val: Mapped[Optional[list[float]]] = mapped_column( ARRAY(Float), nullable=True ) + current_alpha: Mapped[Optional[float]] = mapped_column(Float, nullable=True) + current_beta: Mapped[Optional[float]] = mapped_column(Float, nullable=True) + current_mu: Mapped[Optional[list[float]]] = mapped_column( + ARRAY(Float), nullable=True + ) + current_covariance: Mapped[Optional[list[float]]] = mapped_column( + ARRAY(Float), nullable=True + ) # Relationships arm: Mapped[ArmDB] = relationship("ArmDB", back_populates="draws", lazy="joined") @@ -708,6 +716,10 @@ async def save_observation_to_db( draw.observed_datetime_utc = datetime.now(timezone.utc) draw.observation_type = observation_type draw.reward = reward + draw.current_alpha = draw.arm.alpha + draw.current_beta = draw.arm.beta + draw.current_mu = draw.arm.mu + draw.current_covariance = draw.arm.covariance await asession.commit() await asession.refresh(draw) diff --git a/backend/app/experiments/routers.py b/backend/app/experiments/routers.py index feaba03..dc7eed0 100644 --- a/backend/app/experiments/routers.py +++ b/backend/app/experiments/routers.py @@ -42,11 +42,13 @@ ContextType, DrawResponse, Experiment, - ExperimentSample, + ExperimentResponse, ExperimentsEnum, ObservationType, Outcome, + PlottingData, ) +from .visualization_utils import get_required_plotting_data router = APIRouter(prefix="/experiment", tags=["Experiments"]) @@ -54,12 +56,12 @@ # --- POST experiments routers --- -@router.post("/", response_model=ExperimentSample) +@router.post("/", response_model=ExperimentResponse) async def create_experiment( experiment: Experiment, user_db: Annotated[UserDB, Depends(require_admin_role)], asession: AsyncSession = Depends(get_async_session), -) -> ExperimentSample: +) -> ExperimentResponse: """ Create a new experiment in the current user's workspace. 
""" @@ -87,15 +89,15 @@ async def create_experiment( experiment_dict = experiment_db.to_dict() experiment_dict["notifications"] = [n.to_dict() for n in notifications] - return ExperimentSample.model_validate(experiment_dict) + return ExperimentResponse.model_validate(experiment_dict) # -- GET experiment routers --- -@router.get("/", response_model=list[ExperimentSample]) +@router.get("/", response_model=list[ExperimentResponse]) async def get_all_experiments( user_db: Annotated[UserDB, Depends(get_verified_user)], asession: AsyncSession = Depends(get_async_session), -) -> list[ExperimentSample]: +) -> list[ExperimentResponse]: """ Retrieve all experiments for the current user's workspace. """ @@ -119,12 +121,12 @@ async def get_all_experiments( return all_experiments -@router.get("/type/{experiment_type}", response_model=list[ExperimentSample]) +@router.get("/type/{experiment_type}", response_model=list[ExperimentResponse]) async def get_all_experiments_by_type( experiment_type: ExperimentsEnum, user_db: Annotated[UserDB, Depends(get_verified_user)], asession: AsyncSession = Depends(get_async_session), -) -> list[ExperimentSample]: +) -> list[ExperimentResponse]: """ Retrieve all experiments for the current user's workspace. """ @@ -149,12 +151,12 @@ async def get_all_experiments_by_type( return all_experiments -@router.get("/id/{experiment_id}", response_model=ExperimentSample) +@router.get("/id/{experiment_id}", response_model=ExperimentResponse) async def get_experiment_by_id( experiment_id: int, user_db: Annotated[UserDB, Depends(get_verified_user)], asession: AsyncSession = Depends(get_async_session), -) -> ExperimentSample: +) -> ExperimentResponse: """ Retrieve a specific experiment by ID for the current user's workspace. """ @@ -330,7 +332,7 @@ async def draw_experiment_arm( ) # -- Perform the draw --- - experiment_data = ExperimentSample.model_validate(experiment.to_dict()) + experiment_data = ExperimentResponse.model_validate(experiment.to_dict()) # Validate contexts input if contexts: @@ -476,7 +478,77 @@ async def get_rewards( "arm": [arm for arm in experiment.arms if arm.arm_id == draw.arm_id][0], "reward": draw.reward, "context_val": draw.context_val, + "current_alpha": draw.current_alpha, + "current_beta": draw.current_beta, + "current_mu": draw.current_mu, + "current_covariance": draw.current_covariance, } ) for draw in draws ] + + +@router.get("/{experiment_id}/plotting", response_model=PlottingData) +async def get_plotting_data( + experiment_id: int, + workspace_db: WorkspaceDB = Depends(authenticate_workspace_key), + asession: AsyncSession = Depends(get_async_session), +) -> PlottingData: + """ + Retrieve the data required for plotting. 
+ """ + experiment = await get_experiment_by_id_from_db( + workspace_id=workspace_db.workspace_id, + experiment_id=experiment_id, + asession=asession, + ) + + if not experiment: + raise HTTPException( + status_code=404, detail=f"Experiment with id {experiment_id} not found" + ) + draws = await get_draws_by_experiment_id( + experiment_id=experiment_id, asession=asession + ) + + if not draws: + raise HTTPException( + status_code=404, + detail=f"No draws found for experiment with id {experiment_id}", + ) + + experiment_data = ExperimentResponse.model_validate(experiment.to_dict()) + draw_data = [ + DrawResponse.model_validate( + { + "draw_id": draw.draw_id, + "draw_datetime_utc": str(draw.draw_datetime_utc), + "observed_datetime_utc": str(draw.observed_datetime_utc), + "arm": [arm for arm in experiment.arms if arm.arm_id == draw.arm_id][0], + "reward": draw.reward, + "context_val": draw.context_val, + "current_alpha": draw.current_alpha, + "current_beta": draw.current_beta, + "current_mu": draw.current_mu, + "current_covariance": draw.current_covariance, + } + ) + for draw in draws + ] + try: + plotting_data = get_required_plotting_data( + experiment=experiment_data, draws=draw_data + ) + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error retrieving plotting data: {str(e)}", + ) from e + + return PlottingData( + prior_samples=plotting_data["prior_samples"], + posterior_samples=plotting_data["posterior_samples"], + volumes=plotting_data["volumes"], + posterior_means=plotting_data["posterior_means"], + posterior_stds=plotting_data["posterior_stds"], + ) diff --git a/backend/app/experiments/sampling_utils.py b/backend/app/experiments/sampling_utils.py index 7ae8eca..7f356fe 100644 --- a/backend/app/experiments/sampling_utils.py +++ b/backend/app/experiments/sampling_utils.py @@ -7,7 +7,7 @@ from .schemas import ( ArmPriors, ContextLinkFunctions, - ExperimentSample, + ExperimentResponse, ExperimentsEnum, Outcome, RewardLikelihood, @@ -167,7 +167,7 @@ def objective(theta: np.ndarray) -> float: # ------------- Import functions ---------------- # --- Choose arm function --- def choose_arm( - experiment: ExperimentSample, context: Optional[Union[list, np.ndarray, None]] + experiment: ExperimentResponse, context: Optional[Union[list, np.ndarray, None]] ) -> int: """ Choose arm based on posterior using Thompson Sampling. 
@@ -212,7 +212,7 @@ def choose_arm( # --- Update arm parameters --- def update_arm( - experiment: ExperimentSample, + experiment: ExperimentResponse, rewards: list[float], arm_to_update: Optional[int] = None, context: Optional[Union[list, np.ndarray, None]] = None, diff --git a/backend/app/experiments/schemas.py b/backend/app/experiments/schemas.py index ffd7e6e..36307e4 100644 --- a/backend/app/experiments/schemas.py +++ b/backend/app/experiments/schemas.py @@ -333,6 +333,22 @@ class DrawResponse(BaseModel): description="Context values associated with the draw", default=None, ) + current_alpha: Optional[Union[float, None]] = Field( + description="Current alpha value of the arm", + default=None, + ) + current_beta: Optional[Union[float, None]] = Field( + description="Current beta value of the arm", + default=None, + ) + current_mu: Optional[List[Union[float, None]]] = Field( + description="Current mean value of the arm", + default=None, + ) + current_covariance: Optional[List[List[Union[float, None]]]] = Field( + description="Current covariance matrix of the arm", + default=None, + ) arm: ArmResponse client: Optional[Client] = None @@ -516,33 +532,30 @@ def check_contexts(self) -> Self: class ExperimentResponse(ExperimentBase): """ - Pydantic model for a response for experiment creation + Pydantic model for experiments for drawing and updating arms. """ experiment_id: int n_trials: int last_trial_datetime_utc: Optional[str] = None + observation_type: ObservationType = ObservationType.USER arms: list[ArmResponse] - notifications: list[NotificationsResponse] contexts: Optional[list[ContextResponse]] = None clients: Optional[list[Client]] = None model_config = ConfigDict(from_attributes=True) -class ExperimentSample(ExperimentBase): +class PlottingData(BaseModel): """ - Pydantic model for experiments for drawing and updating arms. + Pydantic model for the data required for plotting. """ - experiment_id: int - n_trials: int - last_trial_datetime_utc: Optional[str] = None - observation_type: ObservationType = ObservationType.USER - - arms: list[ArmResponse] - contexts: Optional[list[ContextResponse]] = None - clients: Optional[list[Client]] = None + posterior_means: list + posterior_stds: list + volumes: list + prior_samples: list + posterior_samples: list model_config = ConfigDict(from_attributes=True) diff --git a/backend/app/experiments/visualization_utils.py b/backend/app/experiments/visualization_utils.py new file mode 100644 index 0000000..d46309d --- /dev/null +++ b/backend/app/experiments/visualization_utils.py @@ -0,0 +1,200 @@ +import numpy as np + +from .schemas import ArmPriors, DrawResponse, ExperimentResponse, ExperimentsEnum + + +def get_experiment_params(experiment: ExperimentResponse) -> tuple[list, list]: + """ + Extracts and returns the parameters of the experiment by generating samples from + prior and posterior distributions. 
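+
+    For a Beta-prior arm this amounts to drawing, e.g.,
+    ``np.random.beta(arm.alpha_init, arm.beta_init, 1000)`` for the prior and
+    ``np.random.beta(arm.alpha, arm.beta, 1000)`` for the posterior; Normal
+    priors are sampled analogously with ``np.random.normal`` (see the code
+    below).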
+ + Args: + experiment (ExperimentResponse): Experiment object containing configuration + and parameters for arms including prior and posterior distribution parameters + + Returns: + tuple[list, list]: Two lists containing: + - prior_samples: List of samples drawn from prior distributions for each arm + - posterior_samples: List of samples drawn from posterior distributions + for each arm + + Raises: + ValueError: If the experiment type or prior type is not supported + (currently supports Beta and Normal priors except for CMAB experiments) + + Notes: + - For Beta priors: Uses alpha and beta parameters to generate samples + - For Normal priors: Uses mean (mu) and standard deviation (sigma) to + generate samples + - Each arm generates 1000 samples from both prior and posterior distributions + - We can use these samples to make boxplots visualizing the + distributions of arms + """ + prior_samples = [] + posterior_samples = [] + try: + if experiment.prior_type == ArmPriors.BETA: + prior_samples = [ + np.random.beta( + float(arm.alpha_init) if arm.alpha_init is not None else 1.0, + float(arm.beta_init) if arm.beta_init is not None else 1.0, + 1000, + ).tolist() + for arm in experiment.arms + ] + posterior_samples = [ + np.random.beta( + float(arm.alpha) if arm.alpha is not None else 1.0, + float(arm.beta) if arm.beta is not None else 1.0, + 1000, + ).tolist() + for arm in experiment.arms + ] + elif ( + experiment.prior_type == ArmPriors.NORMAL + and experiment.exp_type != ExperimentsEnum.CMAB + ): + prior_samples = [ + np.random.normal( + loc=float(arm.mu_init) if arm.mu_init is not None else 0.0, + scale=float(arm.sigma_init) if arm.sigma_init is not None else 1.0, + size=1000, + ).tolist() + for arm in experiment.arms + ] + posterior_samples = [ + np.random.normal( + loc=(float(arm.mu[0]) if arm.mu and arm.mu[0] is not None else 0.0), + scale=( + float(np.array(arm.covariance).ravel()[0]) + if arm.covariance + and np.array(arm.covariance).ravel()[0] is not None + else 1.0 + ), + size=1000, + ).tolist() + for arm in experiment.arms + ] + else: + raise ValueError("Unsupported experiment type or prior type.") + except Exception as e: + raise ValueError(f"Error generating samples: {e}") from e + return prior_samples, posterior_samples + + +def get_posteriors_over_time( + experiment: ExperimentResponse, draws: list[DrawResponse] +) -> tuple[list, list]: + """ + Extracts and returns the posterior samples (means and standard deviations) over + time for each arm in an experiment. + + Args: + experiment (ExperimentResponse): Experiment data containing arms and prior + type information + draws (list[DrawResponse]): List of draw responses containing + posterior distribution parameters + Returns: + tuple[list, list]: Two lists containing: + - means: Posterior mean estimates for each arm over time + - stds: Posterior standard deviation estimates for each arm over time + Raises: + NotImplementedError: If the prior type is not supported (currently + supports Beta and Normal prior distributions, but NOT for CMAB experiments) + Notes: + The function processes draws in reverse chronological order, updating the + mean and std for the drawn arm and maintaining previous values for arms + that are not drawn. + Supports both Beta and Normal prior distributions but NOT for CMAB experiments. 
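+
+        Example (Beta posterior with hypothetical values current_alpha=3,
+        current_beta=7):
+            mean = 3 / (3 + 7) = 0.3
+            std  = sqrt((3 * 7) / ((3 + 7)**2 * (3 + 7 + 1))) ≈ 0.138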
+ """ + arm_id_to_index = {arm.arm_id: i for i, arm in enumerate(experiment.arms)} + + means = np.zeros((len(experiment.arms), len(draws))) + stds = np.zeros((len(experiment.arms), len(draws))) + for i, draw in enumerate(draws[::-1]): + arm_index = arm_id_to_index[draw.arm.arm_id] + not_arm_index = list(arm_id_to_index.values()) + not_arm_index.remove(arm_index) + + if experiment.prior_type == ArmPriors.BETA: + assert ( + draw.current_alpha and draw.current_beta + ), "current_alpha and current_beta must be provided for Beta prior" + means[arm_index, i] = draw.current_alpha / ( + draw.current_alpha + draw.current_beta + ) + stds[arm_index, i] = np.sqrt( + (draw.current_alpha * draw.current_beta) + / ( + (draw.current_alpha + draw.current_beta) ** 2 + * (draw.current_alpha + draw.current_beta + 1) + ) + ) + elif ( + experiment.prior_type == ArmPriors.NORMAL + and experiment.exp_type != ExperimentsEnum.CMAB + ): + assert ( + draw.current_mu and draw.current_covariance + ), "current_mu and current_covariance must be provided for Normal prior" + means[arm_index, i] = np.array(draw.current_mu).ravel()[0] + stds[arm_index, i] = np.sqrt(np.array(draw.current_covariance).ravel())[0] + + else: + raise ValueError( + "Unsupported prior type or experiment type for posterior" + + "calculation" + ) + + for j in not_arm_index: + means[j, i] = means[j, i - 1] + stds[j, i] = stds[j, i - 1] + + return means.tolist(), stds.tolist() + + +def get_volume_assigned_over_time( + experiment: ExperimentResponse, draws: list[DrawResponse] +) -> list: + """ + Extracts and returns the volume assigned to each arm over time. + + Args: + experiment (ExperimentResponse): Experiment data containing arms and prior + type information + draws (list[DrawResponse]): List of draw responses containing arm assignments + Returns: + list: A 2D list where each sublist contains the volume assigned to each arm + at each time step, with the same order as the arms in the experiment. + Notes: + The function processes draws in reverse chronological order, counting the + number of times each arm was assigned and calculating the cumulative volume + assigned to each arm at each time step. + """ + arm_id_to_index = {arm.arm_id: i for i, arm in enumerate(experiment.arms)} + volumes = np.zeros((len(experiment.arms), len(draws))) + arms_assigned = [arm_id_to_index[draw.arm.arm_id] for draw in draws[::-1]] + + for i in arm_id_to_index.values(): + arms_assigned_count = np.cumsum(np.array(arms_assigned) == i) + volumes[i, :] = arms_assigned_count / (np.arange(len(draws)) + 1) + return volumes.tolist() + + +def get_required_plotting_data( + experiment: ExperimentResponse, draws: list[DrawResponse] +) -> dict: + """ + Extracts and returns the data required for plotting. 
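+
+    Combines the three helpers above into a single dict keyed by
+    "prior_samples", "posterior_samples", "posterior_means",
+    "posterior_stds" and "volumes": the samples come from
+    get_experiment_params, the means/stds over time from
+    get_posteriors_over_time, and the per-arm assignment volumes from
+    get_volume_assigned_over_time.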
+ """ + prior_samples, posterior_samples = get_experiment_params(experiment) + posterior_means, posterior_stds = get_posteriors_over_time(experiment, draws) + volumes = get_volume_assigned_over_time(experiment, draws) + + return { + "prior_samples": prior_samples, + "posterior_samples": posterior_samples, + "posterior_means": posterior_means, + "posterior_stds": posterior_stds, + "volumes": volumes, + } diff --git a/backend/migrations/versions/45b9483ee392_fix_messages_foreign_key_constraint.py b/backend/migrations/versions/45b9483ee392_fix_messages_foreign_key_constraint.py index cdfcd4c..8b6de33 100644 --- a/backend/migrations/versions/45b9483ee392_fix_messages_foreign_key_constraint.py +++ b/backend/migrations/versions/45b9483ee392_fix_messages_foreign_key_constraint.py @@ -34,7 +34,7 @@ def downgrade() -> None: op.create_foreign_key( op.f("event_messages_experiment_id_fkey"), "event_messages", - "experiments_base", + "experiments", ["experiment_id"], ["experiment_id"], ) diff --git a/backend/migrations/versions/6101ba814d91_fresh_start.py b/backend/migrations/versions/6101ba814d91_fresh_start.py index d246310..27509e6 100644 --- a/backend/migrations/versions/6101ba814d91_fresh_start.py +++ b/backend/migrations/versions/6101ba814d91_fresh_start.py @@ -122,38 +122,7 @@ def upgrade() -> None: ), sa.PrimaryKeyConstraint("experiment_id"), ) - op.create_table( - "experiments_base", - sa.Column("experiment_id", sa.Integer(), nullable=False), - sa.Column("name", sa.String(length=150), nullable=False), - sa.Column("description", sa.String(length=500), nullable=False), - sa.Column("sticky_assignment", sa.Boolean(), nullable=False), - sa.Column("auto_fail", sa.Boolean(), nullable=False), - sa.Column("auto_fail_value", sa.Integer(), nullable=True), - sa.Column( - "auto_fail_unit", - sa.Enum("DAYS", "HOURS", name="autofailunittype"), - nullable=True, - ), - sa.Column("user_id", sa.Integer(), nullable=False), - sa.Column("workspace_id", sa.Integer(), nullable=False), - sa.Column("is_active", sa.Boolean(), nullable=False), - sa.Column("exp_type", sa.String(length=50), nullable=False), - sa.Column("prior_type", sa.String(length=50), nullable=False), - sa.Column("reward_type", sa.String(length=50), nullable=False), - sa.Column("created_datetime_utc", sa.DateTime(timezone=True), nullable=False), - sa.Column("n_trials", sa.Integer(), nullable=False), - sa.Column("last_trial_datetime_utc", sa.DateTime(timezone=True), nullable=True), - sa.ForeignKeyConstraint( - ["user_id"], - ["users.user_id"], - ), - sa.ForeignKeyConstraint( - ["workspace_id"], - ["workspace.workspace_id"], - ), - sa.PrimaryKeyConstraint("experiment_id"), - ) + op.create_table( "pending_invitations", sa.Column("invitation_id", sa.Integer(), nullable=False), @@ -225,25 +194,6 @@ def upgrade() -> None: ), sa.PrimaryKeyConstraint("arm_id"), ) - op.create_table( - "arms_base", - sa.Column("arm_id", sa.Integer(), nullable=False), - sa.Column("experiment_id", sa.Integer(), nullable=False), - sa.Column("user_id", sa.Integer(), nullable=False), - sa.Column("name", sa.String(length=150), nullable=False), - sa.Column("description", sa.String(length=500), nullable=False), - sa.Column("arm_type", sa.String(length=50), nullable=False), - sa.Column("n_outcomes", sa.Integer(), nullable=False), - sa.ForeignKeyConstraint( - ["experiment_id"], - ["experiments_base.experiment_id"], - ), - sa.ForeignKeyConstraint( - ["user_id"], - ["users.user_id"], - ), - sa.PrimaryKeyConstraint("arm_id"), - ) op.create_table( "clients", sa.Column("client_id", sa.String(), 
nullable=False), @@ -283,7 +233,7 @@ def upgrade() -> None: sa.Column("experiment_id", sa.Integer(), nullable=False), sa.ForeignKeyConstraint( ["experiment_id"], - ["experiments_base.experiment_id"], + ["experiments.experiment_id"], ), sa.ForeignKeyConstraint( ["message_id"], ["messages.message_id"], ondelete="CASCADE" @@ -322,33 +272,6 @@ def upgrade() -> None: ), sa.PrimaryKeyConstraint("notification_id"), ) - op.create_table( - "notifications_db", - sa.Column("notification_id", sa.Integer(), nullable=False), - sa.Column("experiment_id", sa.Integer(), nullable=False), - sa.Column("user_id", sa.Integer(), nullable=False), - sa.Column( - "notification_type", - sa.Enum( - "DAYS_ELAPSED", - "TRIALS_COMPLETED", - "PERCENTAGE_BETTER", - name="eventtype", - ), - nullable=False, - ), - sa.Column("notification_value", sa.Integer(), nullable=False), - sa.Column("is_active", sa.Boolean(), nullable=False), - sa.ForeignKeyConstraint( - ["experiment_id"], - ["experiments_base.experiment_id"], - ), - sa.ForeignKeyConstraint( - ["user_id"], - ["users.user_id"], - ), - sa.PrimaryKeyConstraint("notification_id"), - ) op.create_table( "draws", sa.Column("draw_id", sa.String(), nullable=False), @@ -383,53 +306,19 @@ def upgrade() -> None: ), sa.PrimaryKeyConstraint("draw_id"), ) - op.create_table( - "draws_base", - sa.Column("draw_id", sa.String(), nullable=False), - sa.Column("client_id", sa.String(), nullable=True), - sa.Column("arm_id", sa.Integer(), nullable=False), - sa.Column("experiment_id", sa.Integer(), nullable=False), - sa.Column("user_id", sa.Integer(), nullable=False), - sa.Column("draw_datetime_utc", sa.DateTime(timezone=True), nullable=False), - sa.Column("observed_datetime_utc", sa.DateTime(timezone=True), nullable=True), - sa.Column( - "observation_type", - sa.Enum("USER", "AUTO", name="observationtype"), - nullable=True, - ), - sa.Column("draw_type", sa.String(length=50), nullable=False), - sa.Column("reward", sa.Float(), nullable=True), - sa.ForeignKeyConstraint( - ["arm_id"], - ["arms_base.arm_id"], - ), - sa.ForeignKeyConstraint( - ["experiment_id"], - ["experiments_base.experiment_id"], - ), - sa.ForeignKeyConstraint( - ["user_id"], - ["users.user_id"], - ), - sa.PrimaryKeyConstraint("draw_id"), - ) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_table("draws_base") op.drop_table("draws") - op.drop_table("notifications_db") op.drop_table("notifications") op.drop_table("event_messages") op.drop_table("context") op.drop_table("clients") - op.drop_table("arms_base") op.drop_table("arms") op.drop_table("user_workspace") op.drop_table("pending_invitations") - op.drop_table("experiments_base") op.drop_table("experiments") op.drop_table("api_key_rotation_history") op.drop_table("workspace") diff --git a/backend/migrations/versions/96b30c7be83c_track_current_parameters_with_each_.py b/backend/migrations/versions/96b30c7be83c_track_current_parameters_with_each_.py new file mode 100644 index 0000000..8235bd1 --- /dev/null +++ b/backend/migrations/versions/96b30c7be83c_track_current_parameters_with_each_.py @@ -0,0 +1,43 @@ +"""track current parameters with each reward + +Revision ID: 96b30c7be83c +Revises: 45b9483ee392 +Create Date: 2025-06-25 18:59:02.679588 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. 
+revision: str = "96b30c7be83c" +down_revision: Union[str, None] = "45b9483ee392" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("draws", sa.Column("current_alpha", sa.Float(), nullable=True)) + op.add_column("draws", sa.Column("current_beta", sa.Float(), nullable=True)) + op.add_column( + "draws", sa.Column("current_mu", postgresql.ARRAY(sa.Float()), nullable=True) + ) + op.add_column( + "draws", + sa.Column("current_covariance", postgresql.ARRAY(sa.Float()), nullable=True), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("draws", "current_covariance") + op.drop_column("draws", "current_mu") + op.drop_column("draws", "current_beta") + op.drop_column("draws", "current_alpha") + + # ### end Alembic commands ### diff --git a/backend/requirements.txt b/backend/requirements.txt index b762ddb..ca03ad4 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -15,3 +15,4 @@ sqlalchemy[asyncio]==2.0.20 uvicorn==0.23.2 boto3==1.37.25 pydantic[email]==2.11.3 +matplotlib==3.10.3 diff --git a/backend/tests/test_experiments.py b/backend/tests/test_experiments.py index f2bdd7c..070aa16 100644 --- a/backend/tests/test_experiments.py +++ b/backend/tests/test_experiments.py @@ -535,6 +535,54 @@ def test_get_rewards( assert response.status_code == 200 assert len(response.json()) == n_draws + @mark.parametrize( + "create_experiment_payload, expected_response", + [ + ("base_beta_binom", 200), + ("base_normal", 200), + ("bayes_ab_normal_binom", 200), + ("cmab_normal", 500), + ], + indirect=["create_experiment_payload"], + ) + def test_get_plotting_data( + self, + client: TestClient, + create_experiments: list, + create_experiment_payload: dict, + expected_response: int, + workspace_api_key: str, + ) -> None: + id = create_experiments[0]["experiment_id"] + exp_type = create_experiments[0]["exp_type"] + contexts = None + if exp_type == "cmab": + contexts = [ + {"context_id": context["context_id"], "context_value": 1} + for context in create_experiments[0]["contexts"] + ] + + for _ in range(5): + response = client.put( + f"/experiment/{id}/draw", + headers={"Authorization": f"Bearer {workspace_api_key}"}, + json=contexts, + ) + draw_id = response.json()["draw_id"] + # put outcomes + response = client.put( + f"/experiment/{id}/{draw_id}/1", + headers={"Authorization": f"Bearer {workspace_api_key}"}, + ) + + response = client.get( + f"/experiment/{id}/plotting", + headers={"Authorization": f"Bearer {workspace_api_key}"}, + ) + print(response.status_code) + + assert response.status_code == expected_response + class TestNotifications: @fixture()