Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add interaction plot in the report #64

Merged
merged 1 commit into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ jobs:
strategy:
max-parallel: 1
matrix:
python-version: ["3.9", "3.10", "3.11"]
python-version: ["3.9", "3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ sd.generate_report(

## 🛠 Installation

Eurybia is intended to work with Python versions 3.9 to 3.11. Installation can be done with pip:
Eurybia is intended to work with Python versions 3.9 to 3.12. Installation can be done with pip:

```
pip install eurybia
Expand Down
9 changes: 5 additions & 4 deletions eurybia/core/smartdrift.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import shutil
import tempfile
from pathlib import Path
from typing import Dict

import catboost
import pandas as pd
Expand Down Expand Up @@ -199,12 +198,12 @@ def __init__(
def compile(
self,
full_validation=False,
ignore_cols: list = list(),
ignore_cols: list = None,
sampling=True,
sample_size=100000,
datadrift_file=None,
date_compile_auc=None,
hyperparameter: Dict = catboost_hyperparameter_init.copy(),
hyperparameter: dict = catboost_hyperparameter_init.copy(),
attr_importance="feature_importances_",
):
r"""
Expand Down Expand Up @@ -237,6 +236,8 @@ def compile(
>>> SD.compile()

"""
if ignore_cols is None:
ignore_cols = []
if datadrift_file is not None:
self.datadrift_file = datadrift_file
if hyperparameter is not None:
Expand Down Expand Up @@ -468,7 +469,7 @@ def _analyze_consistency(self, full_validation=False, ignore_cols: list = list()
and will not be analyzed: \n {err_dtypes}"""
)
# Feature values
err_mods: Dict[str, Dict] = {}
err_mods: dict[str, dict] = {}
if full_validation is True:
invalid_cols = ignore_cols + new_cols + removed_cols + err_dtypes
for column in self.df_baseline.columns:
Expand Down
2 changes: 1 addition & 1 deletion eurybia/core/smartplotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@ def generate_modeldrift_data(
if data_modeldrift is None:
data_modeldrift = self.smartdrift.data_modeldrift
if data_modeldrift is None:
raise Exception(
raise ValueError(
"""You should run the add_data_modeldrift method before displaying model drift performances.
For more information see the documentation"""
)
Expand Down
16 changes: 14 additions & 2 deletions eurybia/report/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,6 @@ def get_data_drift_panel(dr: DriftReport) -> pn.Column:
pn.pane.Markdown("### Univariate analysis"),
pn.pane.Markdown(report_text["Data drift"]["07"]),
]
contribution_figures, contribution_labels = dr.display_model_contribution()

distribution_figures, labels, distribution_tables = dr.display_dataset_analysis(global_analysis=False)["univariate"]
distribution_plots_blocks = get_select_plots(
Expand Down Expand Up @@ -262,6 +261,9 @@ def get_data_drift_panel(dr: DriftReport) -> pn.Column:
max_gauge=0.2,
)
blocks += [pn.pane.Plotly(js_fig)]

contribution_figures, contribution_labels = dr.display_model_contribution()

blocks += [
pn.pane.Markdown("## Feature contribution on data drift's detection"),
pn.pane.Markdown(report_text["Data drift"]["09"]),
Expand All @@ -273,14 +275,24 @@ def get_data_drift_panel(dr: DriftReport) -> pn.Column:
figures=contribution_figures,
)
blocks += contribution_plots_blocks

fig_02 = dr.explainer.plot.top_interactions_plot(nb_top_interactions=10)
fig_02.update_layout(width=1240)
blocks += [
pn.pane.Markdown("## Feature interaction on data drift's detection"),
pn.pane.Markdown(report_text["Data drift"]["10"]),
pn.pane.Plotly(fig_02),
]

if dr.smartdrift.historical_auc is not None:
fig = dr.smartdrift.plot.generate_historical_datadrift_metric()
fig.update_layout(width=1240)
blocks += [
pn.pane.Markdown("## Historical Data drift"),
pn.pane.Markdown(report_text["Data drift"]["10"]),
pn.pane.Markdown(report_text["Data drift"]["11"]),
pn.pane.Plotly(fig),
]

return pn.Column(*blocks, name="Data drift", styles=dict(display="none"), css_classes=["data-drift"])


Expand Down
12 changes: 6 additions & 6 deletions eurybia/report/project_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import copy
import logging
import os
from typing import Dict, Optional, Union
from typing import Optional, Union

import jinja2
import pandas as pd
Expand Down Expand Up @@ -36,11 +36,11 @@ class DriftReport:
Attributes
----------
smartdrift: object
SmartDrift object
SmartDrift object
explainer : shapash.explainer.smart_explainer.SmartExplainer
A shapash SmartExplainer object that has already be compiled
A shapash SmartExplainer object that has already be compiled
title_story : str
Report title
Report title
metadata : dict
Information about the project (author, description, ...)
df_predict : pd.DataFrame
Expand All @@ -56,7 +56,7 @@ def __init__(
smartdrift: SmartDrift,
explainer: SmartExplainer,
project_info_file: Optional[str] = None,
config_report: Optional[Dict] = None,
config_report: Optional[dict] = None,
):
"""
Parameters
Expand Down Expand Up @@ -253,7 +253,7 @@ def display_model_contribution(self):
c_list = self.explainer._classes if multiclass else [1] # list just used for multiclass
plot_list = []
labels = []
for index_label, label in enumerate(c_list): # Iterating over all labels in multiclass case
for label in c_list: # Iterating over all labels in multiclass case
for feature in self.features_imp_list:
fig = self.explainer.plot.contribution_plot(feature, label=label, max_points=200)
plot_list.append(fig)
Expand Down
11 changes: 8 additions & 3 deletions eurybia/report/properties.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Dict
from typing import Any

report_text: Dict[str, Any] = {
report_text: dict[str, Any] = {
"Index": {
"01": "- Project information: report context and information",
"02": "- Consistency Analysis: highlighting differences between the two datasets",
Expand Down Expand Up @@ -77,7 +77,12 @@
"This representation constitutes a support to understand the drift "
"when the analysis of the dataset is unclear."
),
"10": ("Line chart showing the metrics evolution of the datadrift classifier over the given period of time."),
"10": (
"This graph represents the interactions between couple of variable to the data drift detection. "
"This representation constitutes a support to understand the drift "
"when the analysis of the dataset is unclear."
),
"11": ("Line chart showing the metrics evolution of the datadrift classifier over the given period of time."),
},
"Model drift": {
"01": (
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ classifiers = [
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
]
Expand Down
Loading