Skip to content

Commit a14469b

Browse files
committed
error handling
1 parent a68fbab commit a14469b

File tree

1 file changed

+71
-16
lines changed

1 file changed

+71
-16
lines changed

zephyr_ml/core.py

Lines changed: 71 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def try_log_backwards_set_method_warning(self, name, next_step):
125125
f"\tAll other steps' results will be considered stale."))
126126

127127
def try_log_backwards_key_method_warning(self, name, next_step):
128-
steps_in_between = self.get_steps_in_between(next_step, self.current_step+1)
128+
steps_in_between = self.get_steps_in_between(next_step, self.current_step + 1)
129129
if len(steps_in_between) > 0:
130130
steps_in_between_str = (f"\tAny results produced by the following steps "
131131
f"will be considered stale:\n"
@@ -353,7 +353,7 @@ def __init__(self):
353353
[self.get_train_test_split]),
354354
([self.fit_pipeline], [self.set_fitted_pipeline], [self.get_fitted_pipeline]),
355355
([self.predict, self.evaluate], [], [])
356-
]
356+
]
357357
self._guide_handler = GuideHandler(step_order)
358358

359359
def GET_ENTITYSET_TYPES(self):
@@ -553,7 +553,9 @@ def generate_label_times(
553553
AssertionError: If entityset has not been generated or set or labeling_fn is
554554
not a string and not callable.
555555
"""
556-
assert self._entityset is not None, "entityset has not been set"
556+
557+
if self._entityset is None:
558+
raise ValueError("entityset has not been set")
557559

558560
if isinstance(labeling_fn, str): # get predefined labeling function
559561
labeling_fn_map = get_labeling_functions_map()
@@ -630,6 +632,9 @@ def get_label_times(self, visualize=False):
630632
Returns:
631633
tuple: (composeml.LabelTimes, dict) The label times and metadata.
632634
"""
635+
if self._label_times is None:
636+
raise ValueError("Label times have not been set"
637+
"Call generate_label_times or set_label_times first.")
633638
if visualize:
634639
cp.label_times.plots.LabelPlots(self._label_times).distribution()
635640
return self._label_times, self._label_times_meta
@@ -724,7 +729,20 @@ def generate_feature_matrix(
724729
Returns:
725730
tuple: (pd.DataFrame, list, featuretools.EntitySet)
726731
Feature matrix, feature definitions, and the processed entityset.
732+
733+
Raises:
734+
ValueError: If required attributes are missing.
727735
"""
736+
if self._entityset is None:
737+
raise ValueError(
738+
"Entityset has not been set. Call generate_entityset or "
739+
"set_entityset first.")
740+
741+
if self._label_times is None:
742+
raise ValueError(
743+
"Label times have not been set. Call generate_label_times or "
744+
"set_label_times first.")
745+
728746
entityset_copy = copy.deepcopy(self._entityset)
729747
# perform signal processing
730748
if signal_dataframe_name is not None and signal_column is not None:
@@ -784,6 +802,9 @@ def get_feature_matrix(self):
784802
tuple: (pd.DataFrame, str, list) The feature matrix, label column name,
785803
and feature definitions.
786804
"""
805+
if self._feature_matrix is None:
806+
raise ValueError("Feature matrix has not been generated. "
807+
"Call generate_feature_matrix or set_feature_matrix first.")
787808
return self._feature_matrix, self._label_col_name, self._features
788809

789810
@guide
@@ -830,6 +851,11 @@ def generate_train_test_split(
830851
Returns:
831852
tuple: (X_train, X_test, y_train, y_test) The split feature matrices and labels.
832853
"""
854+
if self._feature_matrix is None:
855+
raise ValueError(
856+
"Feature matrix has not been generated. Call generate_feature_matrix "
857+
"or set_feature_matrix first.")
858+
833859
feature_matrix = self._feature_matrix.copy()
834860
labels = feature_matrix.pop(self._label_col_name)
835861

@@ -880,7 +906,9 @@ def get_train_test_split(self):
880906
"""
881907
if (self._X_train is None or self._X_test is None or
882908
self._y_train is None or self._y_test is None):
883-
return None
909+
raise ValueError(
910+
"Train-test split has not been generated. "
911+
"Call generate_train_test_split or set_train_test_split first.")
884912
return self._X_train, self._X_test, self._y_train, self._y_test
885913

886914
@guide
@@ -894,8 +922,8 @@ def set_fitted_pipeline(self, pipeline):
894922

895923
@guide
896924
def fit_pipeline(
897-
self, pipeline="xgb_classifier", pipeline_hyperparameters=None,
898-
X=None, y=None, visual=False, **kwargs):
925+
self, pipeline="xgb_classifier",
926+
pipeline_hyperparameters=None, visual=False, **kwargs):
899927
"""Fit a machine learning pipeline.
900928
901929
Args:
@@ -905,28 +933,29 @@ def fit_pipeline(
905933
- Dictionary with pipeline specification
906934
- MLPipeline instance
907935
pipeline_hyperparameters (dict, optional): Hyperparameters for the pipeline.
908-
X (pd.DataFrame, optional): Training features. If None, uses stored training set.
909-
y (array-like, optional): Training labels. If None, uses stored training labels.
910936
visual (bool, optional): Whether to return visualization data. Defaults to False.
911937
**kwargs: Additional arguments passed to the pipeline's fit method.
912938
913939
Returns:
914940
dict or None: If visual=True, returns visualization data dictionary.
941+
942+
Raises:
943+
ValueError: If required attributes are missing.
915944
"""
916-
self._pipeline = self._get_mlpipeline(
917-
pipeline, pipeline_hyperparameters)
945+
if self._X_train is None or self._y_train is None:
946+
raise ValueError(
947+
"No training data provided. Call generate_train_test_split "
948+
"or set_train_test_split first.")
918949

919-
if X is None:
920-
X = self._X_train
921-
if y is None:
922-
y = self._y_train
950+
self._pipeline = self._get_mlpipeline(pipeline, pipeline_hyperparameters)
923951

924952
if visual:
925953
outputs_spec, visual_names = self._get_outputs_spec(False)
926954
else:
927955
outputs_spec = None
928956

929-
outputs = self._pipeline.fit(X, y, output_=outputs_spec, **kwargs)
957+
outputs = self._pipeline.fit(X=self._X_train, y=self._y_train,
958+
output_=outputs_spec, **kwargs)
930959

931960
if visual and outputs is not None:
932961
return dict(zip(visual_names, outputs))
@@ -951,9 +980,22 @@ def predict(self, X=None, visual=False, **kwargs):
951980
952981
Returns:
953982
array-like or tuple: Predictions, and if visual=True, also returns visualization data.
983+
984+
Raises:
985+
ValueError: If required attributes or parameters are missing.
954986
"""
955-
if X is None:
987+
if self._pipeline is None:
988+
raise ValueError(
989+
"No pipeline has been fitted. Call fit_pipeline or set_fitted_pipeline first.")
990+
991+
if X is None and self._X_test is None:
992+
raise ValueError(
993+
"No test data provided. Pass in test data or "
994+
"call generate_train_test_split or set_train_test_split first.")
995+
996+
elif X is None:
956997
X = self._X_test
998+
957999
if visual:
9581000
outputs_spec, visual_names = self._get_outputs_spec()
9591001
else:
@@ -984,9 +1026,22 @@ def evaluate(
9841026
9851027
Returns:
9861028
dict: A dictionary mapping metric names to their computed values.
1029+
1030+
Raises:
1031+
ValueError: If required attributes are missing.
9871032
"""
1033+
if self._pipeline is None:
1034+
raise ValueError(
1035+
"No pipeline has been fitted. Call fit_pipeline or set_fitted_pipeline first.")
1036+
1037+
if (X is None and self._X_test is None) or (y is None and self._y_test is None):
1038+
raise ValueError(
1039+
"No test data provided. Pass in test data or "
1040+
"call generate_train_test_split or set_train_test_split first.")
1041+
9881042
if X is None:
9891043
X = self._X_test
1044+
9901045
if y is None:
9911046
y = self._y_test
9921047

0 commit comments

Comments
 (0)