diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 897189483..28188df92 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -240,15 +240,8 @@ def _validate_all_tables(self, data): return errors - def validate(self, data): - """Validate the data. - - Validate that the metadata matches the data and thta every table's constraints are valid. - - Args: - data (dict): - A dictionary of table names to pd.DataFrames. - """ + def _validate(self, data): + """Validate metadata, constraints, and data.""" errors = [] constraints_errors = [] self.metadata.validate_data(data) @@ -268,6 +261,17 @@ def validate(self, data): elif errors: raise InvalidDataError(errors) + def validate(self, data): + """Validate the data. + + Validate that the metadata matches the data and thta every table's constraints are valid. + + Args: + data (dict): + A dictionary of table names to pd.DataFrames. + """ + self._validate(data) + def _validate_table_name(self, table_name): if table_name not in self._table_synthesizers: raise ValueError( @@ -368,8 +372,8 @@ def preprocess(self, data): """ list_of_changed_tables = self._store_and_convert_original_cols(data) - data = self._transform_helper(data) self.validate(data) + data = self._transform_helper(data) if self._fitted: warnings.warn( 'This model has already been fitted. To use the new preprocessed data, '