Skip to content

Add backend support for visualizations #90

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 84 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
84 commits
Select commit Hold shift + click to select a range
3f0b611
Add user verification and password reset
Jayprakash-SE Apr 2, 2025
3e86241
Added env variables for test
Jayprakash-SE Apr 2, 2025
b183861
Fix redirect issue
Jayprakash-SE Apr 3, 2025
a2923ef
Merge remote-tracking branch 'origin' into workspaces
Jayprakash-SE Apr 15, 2025
c79830b
Adding the workspace feature
Jayprakash-SE Apr 21, 2025
219aea7
File formating
Jayprakash-SE Apr 21, 2025
88cc7aa
Merge branch 'main' into workspaces
Jayprakash-SE Apr 30, 2025
43a2c38
Added workspace removal and user list
Jayprakash-SE Apr 30, 2025
64768d6
New frontend
Jayprakash-SE May 1, 2025
700fb36
Remove old sidebar
Jayprakash-SE May 1, 2025
b7ff0b5
Fix email link
Jayprakash-SE May 1, 2025
609f9d8
Removed the create workspace from switcher
Jayprakash-SE May 1, 2025
b508c4b
Add API key history
Jayprakash-SE May 4, 2025
fcbc483
Merge remote-tracking branch 'origin' into workspaces
Jayprakash-SE May 4, 2025
daf1f9f
Fix the tests
Jayprakash-SE May 4, 2025
a84527b
Fix npm build errors
Jayprakash-SE May 4, 2025
4544023
Changed workspace UI
Jayprakash-SE May 5, 2025
6488f97
Fix errors
Jayprakash-SE May 5, 2025
f629487
Fix circular import
Jayprakash-SE May 5, 2025
ba23c1c
Fix the mypy error
Jayprakash-SE May 6, 2025
3f5c0ec
Fix Ruff errors
Jayprakash-SE May 7, 2025
6026bb5
Fix formating
Jayprakash-SE May 7, 2025
68fbc7c
Squashed commit of the following:
poornimaramesh May 9, 2025
424cd4d
Addressed comments
Jayprakash-SE May 14, 2025
f9a5471
Merge remote-tracking branch 'origin' into workspaces
Jayprakash-SE May 14, 2025
64ae3ad
Minor fix
Jayprakash-SE May 14, 2025
02f0c73
Fixed import
Jayprakash-SE May 14, 2025
58aebfb
Removed user filter
Jayprakash-SE May 14, 2025
aec6484
Merge branch 'main' into refactor
poornimaramesh May 14, 2025
ebb7cd3
Addressed comments
Jayprakash-SE May 21, 2025
2cbf107
Fix the removal of user from workspace
Jayprakash-SE May 21, 2025
ccf227b
Fix the ts errors
Jayprakash-SE May 21, 2025
2783946
WIP: models and schemas
poornimaramesh May 26, 2025
8b52b1c
WIP: models and schemas
poornimaramesh May 26, 2025
c0aa9c4
Squashed commit of the following:
poornimaramesh May 26, 2025
8bb9ecc
Merge remote-tracking branch 'origin/workspaces' into refactor
poornimaramesh May 26, 2025
b062c7c
fix tests
poornimaramesh May 26, 2025
196c3bb
working version of models, schemas and routers for creating experiments
poornimaramesh May 27, 2025
1cff241
working version of get all mabs endpoint
poornimaramesh May 27, 2025
797b050
add epxeriment type get router
poornimaramesh May 27, 2025
1f3beb3
add get routers for exp by id
poornimaramesh May 27, 2025
2947b90
debugging endpoints
poornimaramesh May 27, 2025
3e8aa12
add bulk delete router
poornimaramesh May 27, 2025
566ca36
debugging, WIP sampling utils
poornimaramesh May 28, 2025
483d481
fix linting
poornimaramesh May 28, 2025
d9dacda
WIP: draw arm + update arm routers
poornimaramesh May 29, 2025
4ab45c5
fresh start migrations
poornimaramesh May 29, 2025
94373ce
add choose arm and update arm routers + functions
poornimaramesh May 30, 2025
3b273c4
fix linting
poornimaramesh May 30, 2025
156cbe3
debug routers for beta-binary mab
poornimaramesh May 30, 2025
ec2b7bd
debug normal/real-valued experiments
poornimaramesh May 30, 2025
b9c583d
debug Bayes AB beta-binom
poornimaramesh Jun 3, 2025
7050907
debug mabs
poornimaramesh Jun 3, 2025
0f86e7c
debug cmabs
poornimaramesh Jun 3, 2025
6dc932d
resolve merge conflicts with main
poornimaramesh Jun 3, 2025
fb2904b
update autofail
poornimaramesh Jun 3, 2025
36a0db5
delete old routers and migrations
poornimaramesh Jun 3, 2025
44ba5dc
Merge branch 'main' into refactor
poornimaramesh Jun 4, 2025
40e6c77
debugging
poornimaramesh Jun 4, 2025
6ac9a31
fix messages test
poornimaramesh Jun 5, 2025
741245b
fix notifications and tests
poornimaramesh Jun 5, 2025
8eea2db
debug autofail and fix corresponding tests
poornimaramesh Jun 5, 2025
2f63c14
experiment tests for mabs
poornimaramesh Jun 6, 2025
09aecdc
update tests for Bayes AB experiments
poornimaramesh Jun 6, 2025
8701324
add tests for CMAB
poornimaramesh Jun 6, 2025
4a48364
delete old tests
poornimaramesh Jun 6, 2025
4845a44
merge changes from refactor
poornimaramesh Jun 6, 2025
d71746a
merge changes from tests
poornimaramesh Jun 6, 2025
0d5a574
update display experiments page
poornimaramesh Jun 9, 2025
c8bb88d
fix experiment viz
poornimaramesh Jun 9, 2025
4e0078e
debug prior-reward config
poornimaramesh Jun 9, 2025
3d64bf7
add context page
poornimaramesh Jun 10, 2025
2365592
add input arms page
poornimaramesh Jun 10, 2025
1bfb017
add notifications
poornimaramesh Jun 10, 2025
8a6f5c6
debugging
poornimaramesh Jun 10, 2025
b80826d
clean up
poornimaramesh Jun 10, 2025
66f639d
merge changes from base refactor
poornimaramesh Jun 19, 2025
cbbf55d
fix tests
poornimaramesh Jun 19, 2025
82f3355
merge chnges from other refactor branches
poornimaramesh Jun 19, 2025
9da3f9c
Merge branch 'refactor-frontend' into visualization
poornimaramesh Jun 25, 2025
f9310c5
update migration files, requirements, and log outcomes corectly
poornimaramesh Jun 25, 2025
6034fd0
fix tests
poornimaramesh Jun 26, 2025
a3a7ef1
fix mypy errors
poornimaramesh Jun 26, 2025
e02fdd7
add test for plotting data router
poornimaramesh Jun 26, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions backend/app/experiments/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .schemas import (
ArmPriors,
ArmResponse,
ExperimentSample,
ExperimentResponse,
ExperimentsEnum,
NotificationsResponse,
ObservationType,
Expand All @@ -30,7 +30,7 @@
async def experiments_db_to_schema(
experiments_db: list[ExperimentDB],
asession: AsyncSession,
) -> list[ExperimentSample]:
) -> list[ExperimentResponse]:
"""
Convert a list of ExperimentDB objects to a list of ExperimentResponse schemas.
"""
Expand All @@ -47,7 +47,7 @@ async def experiments_db_to_schema(
)
]
all_experiments.append(
ExperimentSample.model_validate(
ExperimentResponse.model_validate(
{
**exp_dict,
"notifications": [
Expand Down Expand Up @@ -269,7 +269,7 @@ async def update_arm_parameters(
treatments: Union[list[float], None],
) -> None:
"""Update the arm parameters based on the reward type and outcome"""
experiment_data = ExperimentSample.model_validate(experiment.to_dict())
experiment_data = ExperimentResponse.model_validate(experiment.to_dict())
if experiment_data.reward_type == RewardLikelihood.BERNOULLI:
Outcome(rewards[0]) # Check if reward is 0 or 1
params = update_arm(
Expand Down
12 changes: 12 additions & 0 deletions backend/app/experiments/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,14 @@ class DrawDB(Base):
context_val: Mapped[Optional[list[float]]] = mapped_column(
ARRAY(Float), nullable=True
)
current_alpha: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
current_beta: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
current_mu: Mapped[Optional[list[float]]] = mapped_column(
ARRAY(Float), nullable=True
)
current_covariance: Mapped[Optional[list[float]]] = mapped_column(
ARRAY(Float), nullable=True
)

# Relationships
arm: Mapped[ArmDB] = relationship("ArmDB", back_populates="draws", lazy="joined")
Expand Down Expand Up @@ -708,6 +716,10 @@ async def save_observation_to_db(
draw.observed_datetime_utc = datetime.now(timezone.utc)
draw.observation_type = observation_type
draw.reward = reward
draw.current_alpha = draw.arm.alpha
draw.current_beta = draw.arm.beta
draw.current_mu = draw.arm.mu
draw.current_covariance = draw.arm.covariance

await asession.commit()
await asession.refresh(draw)
Expand Down
94 changes: 83 additions & 11 deletions backend/app/experiments/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,24 +42,26 @@
ContextType,
DrawResponse,
Experiment,
ExperimentSample,
ExperimentResponse,
ExperimentsEnum,
ObservationType,
Outcome,
PlottingData,
)
from .visualization_utils import get_required_plotting_data

router = APIRouter(prefix="/experiment", tags=["Experiments"])

logger = setup_logger(__name__)


# --- POST experiments routers ---
@router.post("/", response_model=ExperimentSample)
@router.post("/", response_model=ExperimentResponse)
async def create_experiment(
experiment: Experiment,
user_db: Annotated[UserDB, Depends(require_admin_role)],
asession: AsyncSession = Depends(get_async_session),
) -> ExperimentSample:
) -> ExperimentResponse:
"""
Create a new experiment in the current user's workspace.
"""
Expand Down Expand Up @@ -87,15 +89,15 @@ async def create_experiment(

experiment_dict = experiment_db.to_dict()
experiment_dict["notifications"] = [n.to_dict() for n in notifications]
return ExperimentSample.model_validate(experiment_dict)
return ExperimentResponse.model_validate(experiment_dict)


# -- GET experiment routers ---
@router.get("/", response_model=list[ExperimentSample])
@router.get("/", response_model=list[ExperimentResponse])
async def get_all_experiments(
user_db: Annotated[UserDB, Depends(get_verified_user)],
asession: AsyncSession = Depends(get_async_session),
) -> list[ExperimentSample]:
) -> list[ExperimentResponse]:
"""
Retrieve all experiments for the current user's workspace.
"""
Expand All @@ -119,12 +121,12 @@ async def get_all_experiments(
return all_experiments


@router.get("/type/{experiment_type}", response_model=list[ExperimentSample])
@router.get("/type/{experiment_type}", response_model=list[ExperimentResponse])
async def get_all_experiments_by_type(
experiment_type: ExperimentsEnum,
user_db: Annotated[UserDB, Depends(get_verified_user)],
asession: AsyncSession = Depends(get_async_session),
) -> list[ExperimentSample]:
) -> list[ExperimentResponse]:
"""
Retrieve all experiments for the current user's workspace.
"""
Expand All @@ -149,12 +151,12 @@ async def get_all_experiments_by_type(
return all_experiments


@router.get("/id/{experiment_id}", response_model=ExperimentSample)
@router.get("/id/{experiment_id}", response_model=ExperimentResponse)
async def get_experiment_by_id(
experiment_id: int,
user_db: Annotated[UserDB, Depends(get_verified_user)],
asession: AsyncSession = Depends(get_async_session),
) -> ExperimentSample:
) -> ExperimentResponse:
"""
Retrieve a specific experiment by ID for the current user's workspace.
"""
Expand Down Expand Up @@ -330,7 +332,7 @@ async def draw_experiment_arm(
)

# -- Perform the draw ---
experiment_data = ExperimentSample.model_validate(experiment.to_dict())
experiment_data = ExperimentResponse.model_validate(experiment.to_dict())

# Validate contexts input
if contexts:
Expand Down Expand Up @@ -476,7 +478,77 @@ async def get_rewards(
"arm": [arm for arm in experiment.arms if arm.arm_id == draw.arm_id][0],
"reward": draw.reward,
"context_val": draw.context_val,
"current_alpha": draw.current_alpha,
"current_beta": draw.current_beta,
"current_mu": draw.current_mu,
"current_covariance": draw.current_covariance,
}
)
for draw in draws
]


@router.get("/{experiment_id}/plotting", response_model=PlottingData)
async def get_plotting_data(
experiment_id: int,
workspace_db: WorkspaceDB = Depends(authenticate_workspace_key),
asession: AsyncSession = Depends(get_async_session),
) -> PlottingData:
"""
Retrieve the data required for plotting.
"""
experiment = await get_experiment_by_id_from_db(
workspace_id=workspace_db.workspace_id,
experiment_id=experiment_id,
asession=asession,
)

if not experiment:
raise HTTPException(
status_code=404, detail=f"Experiment with id {experiment_id} not found"
)
draws = await get_draws_by_experiment_id(
experiment_id=experiment_id, asession=asession
)

if not draws:
raise HTTPException(
status_code=404,
detail=f"No draws found for experiment with id {experiment_id}",
)

experiment_data = ExperimentResponse.model_validate(experiment.to_dict())
draw_data = [
DrawResponse.model_validate(
{
"draw_id": draw.draw_id,
"draw_datetime_utc": str(draw.draw_datetime_utc),
"observed_datetime_utc": str(draw.observed_datetime_utc),
"arm": [arm for arm in experiment.arms if arm.arm_id == draw.arm_id][0],
"reward": draw.reward,
"context_val": draw.context_val,
"current_alpha": draw.current_alpha,
"current_beta": draw.current_beta,
"current_mu": draw.current_mu,
"current_covariance": draw.current_covariance,
}
)
for draw in draws
]
try:
plotting_data = get_required_plotting_data(
experiment=experiment_data, draws=draw_data
)
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Error retrieving plotting data: {str(e)}",
) from e

return PlottingData(
prior_samples=plotting_data["prior_samples"],
posterior_samples=plotting_data["posterior_samples"],
volumes=plotting_data["volumes"],
posterior_means=plotting_data["posterior_means"],
posterior_stds=plotting_data["posterior_stds"],
)
6 changes: 3 additions & 3 deletions backend/app/experiments/sampling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .schemas import (
ArmPriors,
ContextLinkFunctions,
ExperimentSample,
ExperimentResponse,
ExperimentsEnum,
Outcome,
RewardLikelihood,
Expand Down Expand Up @@ -167,7 +167,7 @@ def objective(theta: np.ndarray) -> float:
# ------------- Import functions ----------------
# --- Choose arm function ---
def choose_arm(
experiment: ExperimentSample, context: Optional[Union[list, np.ndarray, None]]
experiment: ExperimentResponse, context: Optional[Union[list, np.ndarray, None]]
) -> int:
"""
Choose arm based on posterior using Thompson Sampling.
Expand Down Expand Up @@ -212,7 +212,7 @@ def choose_arm(

# --- Update arm parameters ---
def update_arm(
experiment: ExperimentSample,
experiment: ExperimentResponse,
rewards: list[float],
arm_to_update: Optional[int] = None,
context: Optional[Union[list, np.ndarray, None]] = None,
Expand Down
37 changes: 25 additions & 12 deletions backend/app/experiments/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,22 @@ class DrawResponse(BaseModel):
description="Context values associated with the draw",
default=None,
)
current_alpha: Optional[Union[float, None]] = Field(
description="Current alpha value of the arm",
default=None,
)
current_beta: Optional[Union[float, None]] = Field(
description="Current beta value of the arm",
default=None,
)
current_mu: Optional[List[Union[float, None]]] = Field(
description="Current mean value of the arm",
default=None,
)
current_covariance: Optional[List[List[Union[float, None]]]] = Field(
description="Current covariance matrix of the arm",
default=None,
)
arm: ArmResponse
client: Optional[Client] = None

Expand Down Expand Up @@ -516,33 +532,30 @@ def check_contexts(self) -> Self:

class ExperimentResponse(ExperimentBase):
"""
Pydantic model for a response for experiment creation
Pydantic model for experiments for drawing and updating arms.
"""

experiment_id: int
n_trials: int
last_trial_datetime_utc: Optional[str] = None
observation_type: ObservationType = ObservationType.USER

arms: list[ArmResponse]
notifications: list[NotificationsResponse]
contexts: Optional[list[ContextResponse]] = None
clients: Optional[list[Client]] = None

model_config = ConfigDict(from_attributes=True)


class ExperimentSample(ExperimentBase):
class PlottingData(BaseModel):
"""
Pydantic model for experiments for drawing and updating arms.
Pydantic model for the data required for plotting.
"""

experiment_id: int
n_trials: int
last_trial_datetime_utc: Optional[str] = None
observation_type: ObservationType = ObservationType.USER

arms: list[ArmResponse]
contexts: Optional[list[ContextResponse]] = None
clients: Optional[list[Client]] = None
posterior_means: list
posterior_stds: list
volumes: list
prior_samples: list
posterior_samples: list

model_config = ConfigDict(from_attributes=True)
Loading