Skip to content

Commit 94344e4

Browse files
UI for whitebox server using streamlit (#131)
1 parent dde8450 commit 94344e4

40 files changed

+2562
-37
lines changed

.streamlit/config.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
[theme]
2+
base="dark"
3+
primaryColor="#21babe"
4+
backgroundColor="#1e2025"
5+
secondaryBackgroundColor="#252a33"
6+

docs/mkdocs/docs/sdk-docs.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@ This is the documentation for Whitebox's SDK. For an interactive experience, you
44

55
## Models
66

7-
**_create_model_**_(name, type, prediction, labels=None, description="")_
7+
**_create_model_**_(name, type, target_column, labels=None, description="")_
88

99
Creates a model in the database. This model works as placeholder for all the actual model's metadata.
1010

1111
| Parameter | Type | Description |
1212
| --------------- | ---------------- | ------------------------------------------------------------------------- |
1313
| **name** | `str` | The name of the model. |
1414
| **type** | `str` | The model's type. Possible values: `binary`, `multi_class`, `regression`. |
15-
| **prediction** | `str` | The prediction of the model. |
15+
| **target_column** | `str` | The name of the target column (y). |
1616
| **labels** | `Dict[str, int]` | The model's labels. Defaults to `None`. |
1717
| **description** | `str` | The model's description. Defaults to an empty string `""`. |
1818

docs/mkdocs/docs/tutorial/sdk.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ wb.create_model(
7373
'additionalProp1': 0,
7474
'additionalProp2': 1
7575
},
76-
prediction="target"
76+
target_column="target"
7777
)
7878
```
7979

examples/notebooks/sdk-example.ipynb

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@
7474
}
7575
],
7676
"source": [
77-
"wb.create_model(name=\"Model 1\", type=\"binary\", labels={'additionalProp1': 0, 'additionalProp2': 1}, prediction=\"y_prediction_multi\")"
77+
"wb.create_model(name=\"Model 1\", type=\"binary\", labels={'additionalProp1': 0, 'additionalProp2': 1}, target_column=\"y_prediction_multi\")"
7878
]
7979
},
8080
{
@@ -248,7 +248,7 @@
248248
],
249249
"metadata": {
250250
"kernelspec": {
251-
"display_name": "Python 3",
251+
"display_name": ".venv",
252252
"language": "python",
253253
"name": "python3"
254254
},
@@ -262,12 +262,12 @@
262262
"name": "python",
263263
"nbconvert_exporter": "python",
264264
"pygments_lexer": "ipython3",
265-
"version": "3.10.8"
265+
"version": "3.8.6"
266266
},
267267
"orig_nbformat": 4,
268268
"vscode": {
269269
"interpreter": {
270-
"hash": "32a5a47fe20cdfbd609287887f2e78a4d5f2f7afeda3da775d5794970f9a3f8e"
270+
"hash": "6b8a8ae524dcca06b04542d2d49be160be0f11dd19d43d7b2673d555344c6092"
271271
}
272272
}
273273
},

requirements.txt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,16 +88,29 @@ setuptools==67.1.0
8888
six==1.16.0
8989
sniffio==1.3.0
9090
SQLAlchemy==1.4.46
91+
stack-data==0.5.1
9192
starlette==0.23.1
9293
statsmodels==0.13.5
94+
streamlit==1.18.1
9395
tenacity==8.2.1
9496
threadpoolctl==3.1.0
9597
tifffile==2023.2.3
9698
tomli==2.0.1
9799
tqdm==4.64.1
100+
traitlets==5.4.0
101+
typed-ast==1.5.4
102+
typer==0.6.1
103+
types-python-dateutil==2.8.19
98104
typing_extensions==4.4.0
105+
unicodedata2==14.0.0
99106
urllib3==1.26.14
100107
uvicorn==0.20.0
101108
virtualenv==20.19.0
102109
watchdog==2.2.1
110+
wcwidth==0.2.5
103111
wheel==0.38.4
112+
wrapt==1.14.1
113+
xgboost==1.6.2
114+
yarl==1.8.1
115+
zipp==3.10.0
116+

whitebox/.streamlit/config.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
[theme]
2+
base="dark"
3+
primaryColor="#21babe"
4+
backgroundColor="#1e2025"
5+
secondaryBackgroundColor="#252a33"
6+

whitebox/api/v1/dataset_rows.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,15 @@ async def create_dataset_rows(
4545
model = crud.models.get(db=db, _id=dict(body[0])["model_id"])
4646
if model:
4747
for row in body:
48-
if not model.prediction in row.processed:
48+
if not model.target_column in row.processed:
4949
return errors.bad_request(
50-
f'Column "{model.prediction}" was not found in some or any of the rows in provided training dataset. Please try again!'
50+
f'Column "{model.target_column}" was not found in some or any of the rows in provided training dataset. Please try again!'
5151
)
5252

53-
predictions = list(set(vars(x)["processed"][model.prediction] for x in body))
53+
predictions = list(set(vars(x)["processed"][model.target_column] for x in body))
5454
if len(predictions) <= 1:
5555
return errors.bad_request(
56-
f'Training dataset\'s "{model.prediction}" columns must have at least 2 different values!'
56+
f'Training dataset\'s "{model.target_column}" columns must have at least 2 different values!'
5757
)
5858

5959
new_dataset_rows = crud.dataset_rows.create_many(db=db, obj_list=body)
@@ -66,21 +66,21 @@ async def create_dataset_rows(
6666
background_tasks.add_task(
6767
create_binary_classification_training_model_pipeline,
6868
processed_dataset_rows_pd,
69-
model.prediction,
69+
model.target_column,
7070
model.id,
7171
)
7272
elif model.type == ModelType.multi_class:
7373
background_tasks.add_task(
7474
create_multiclass_classification_training_model_pipeline,
7575
processed_dataset_rows_pd,
76-
model.prediction,
76+
model.target_column,
7777
model.id,
7878
)
7979
elif model.type == ModelType.regression:
8080
background_tasks.add_task(
8181
create_regression_training_model_pipeline,
8282
processed_dataset_rows_pd,
83-
model.prediction,
83+
model.target_column,
8484
model.id,
8585
)
8686
return new_dataset_rows

whitebox/api/v1/inference_rows.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,18 @@ async def create_many_inference_rows(
5050
) -> List[InferenceRow]:
5151
"""Inserts a set of inference rows into the database."""
5252

53-
new_inference_rows = crud.inference_rows.create_many(db=db, obj_list=body)
54-
return new_inference_rows
53+
model = crud.models.get(db=db, _id=dict(body[0])["model_id"])
54+
if model:
55+
for row in body:
56+
if not model.target_column in row.processed:
57+
return errors.bad_request(
58+
f'Column "{model.target_column}" was not found in some or any of the rows in provided inference dataset. Please try again!'
59+
)
60+
61+
new_inference_rows = crud.inference_rows.create_many(db=db, obj_list=body)
62+
return new_inference_rows
63+
else:
64+
return errors.not_found(f"Model with id: {dict(body[0])['model_id']} not found")
5565

5666

5767
@inference_rows_router.get(
@@ -138,7 +148,7 @@ async def create_inference_row_xai_report(
138148

139149
xai_report = create_xai_pipeline_per_inference_row(
140150
training_set=pd.DataFrame(dataset_rows_processed),
141-
target=model.prediction,
151+
target=model.target_column,
142152
inference_row=inference_row_series,
143153
type_of_task=model.type,
144154
model_id=model.id,

whitebox/cron_tasks/monitoring_metrics.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,10 @@ async def run_calculate_drifting_metrics_pipeline(
5151

5252
# We need to drop the target column from the data to calculate drifting metrics
5353
processed_inference_dropped_target_df = inference_processed_df.drop(
54-
[model.prediction], axis=1
54+
[model.target_column], axis=1
5555
)
5656
processed_training_dropped_target_df = training_processed_df.drop(
57-
[model.prediction], axis=1
57+
[model.target_column], axis=1
5858
)
5959

6060
data_drift_report = run_data_drift_pipeline(
@@ -63,7 +63,7 @@ async def run_calculate_drifting_metrics_pipeline(
6363
concept_drift_report = run_concept_drift_pipeline(
6464
training_processed_df,
6565
inference_processed_df,
66-
model.prediction,
66+
model.target_column,
6767
)
6868

6969
new_drifting_metric = entities.DriftingMetric(
@@ -115,7 +115,7 @@ async def run_calculate_performance_metrics_pipeline(
115115
if model.type == ModelType.binary:
116116
binary_classification_metrics_report = (
117117
create_binary_classification_evaluation_metrics_pipeline(
118-
cleaned_actuals_df, inference_processed_df[model.prediction], labels
118+
cleaned_actuals_df, inference_processed_df[model.target_column], labels
119119
)
120120
)
121121

@@ -130,7 +130,7 @@ async def run_calculate_performance_metrics_pipeline(
130130
elif model.type == ModelType.multi_class:
131131
multiclass_classification_metrics_report = (
132132
create_multiple_classification_evaluation_metrics_pipeline(
133-
cleaned_actuals_df, inference_processed_df[model.prediction], labels
133+
cleaned_actuals_df, inference_processed_df[model.target_column], labels
134134
)
135135
)
136136

@@ -144,7 +144,7 @@ async def run_calculate_performance_metrics_pipeline(
144144

145145
elif model.type == ModelType.regression:
146146
regression_metrics_report = create_regression_evaluation_metrics_pipeline(
147-
cleaned_actuals_df, inference_processed_df[model.prediction]
147+
cleaned_actuals_df, inference_processed_df[model.target_column]
148148
)
149149

150150
new_performance_metric = entities.RegressionMetrics(

whitebox/entities/Model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class Model(Base):
1313
description = Column(String)
1414
type = Column("type", Enum(ModelType))
1515
labels = Column(JSON, nullable=True)
16-
prediction = Column(String)
16+
target_column = Column(String)
1717
created_at = Column(DateTime)
1818
updated_at = Column(DateTime)
1919

0 commit comments

Comments
 (0)