diff --git a/sapientml_core/explain/main.py b/sapientml_core/explain/main.py
index 0b17761..57783b5 100644
--- a/sapientml_core/explain/main.py
+++ b/sapientml_core/explain/main.py
@@ -12,11 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import collections
 from typing import Literal, Optional
 
 import pandas as pd
 from sapientml.params import CancellationToken
 from sapientml.util.logging import setup_logger
+from sapientml_preprocess.generator import check_cols_has_symbols, remove_symbols
 
 from .AutoEDA import EDA
 from .AutoVisualization import AutoVisualization_Class
@@ -81,12 +83,49 @@ def process(
     if visualization:
         # Call AutoVisualization to generate visualization codes
         AV = AutoVisualization_Class()
-        visualization_code = AV.AutoVisualization(
-            df=dataframe,
-            target_columns=target_columns,
-            problem_type=problem_type,
-            ignore_columns=ignore_columns,
-        )
+        # Map each symbol-stripped column name back to its original name (rename_dict: stripped -> original).
+        cols_has_symbols = check_cols_has_symbols(dataframe.columns.to_list())
+        no_symbol_columns = [col for col in dataframe.columns.values if col not in cols_has_symbols]
+        rename_dict = {}
+        if cols_has_symbols:
+            df = list(
+                dataframe.rename(columns=lambda col: remove_symbols(col) if col in cols_has_symbols else col).columns
+            )
+            rename_dict = {}
+            # Stripped names that collide with an untouched column get a countdown suffix to stay unique.
+            same_column = {k: v for k, v in collections.Counter(df).items() if v > 1 and k in no_symbol_columns}
+            for target, org_column in zip(df, dataframe.columns.tolist()):
+                if target in same_column.keys():
+                    rename_dict[target + str(same_column[target] - 1)] = org_column
+                    same_column[target] = same_column[target] - 1
+                else:
+                    rename_dict[target] = org_column
+
+            df = list(rename_dict.values())
+
+            if len(rename_dict) != 0:
+                # A dict view never compares equal to a list, so test each original name for membership instead.
+                col_has_target = [target for target, original in rename_dict.items() if original in target_columns]
+                visualization_code = AV.AutoVisualization(
+                    df=dataframe,
+                    target_columns=col_has_target,
+                    problem_type=problem_type,
+                    ignore_columns=ignore_columns,
+                )
+            else:
+                visualization_code = AV.AutoVisualization(
+                    df=dataframe,
+                    target_columns=target_columns,
+                    problem_type=problem_type,
+                    ignore_columns=ignore_columns,
+                )
+        else:
+            visualization_code = AV.AutoVisualization(
+                df=dataframe,
+                target_columns=target_columns,
+                problem_type=problem_type,
+                ignore_columns=ignore_columns,
+            )
     else:
         visualization_code = None
 
diff --git a/sapientml_core/generator.py b/sapientml_core/generator.py
index 251e68e..2b089d9 100644
--- a/sapientml_core/generator.py
+++ b/sapientml_core/generator.py
@@ -221,8 +221,44 @@ def generate_pipeline(self, dataset: Dataset, task: Task):
         for pipeline in sapientml_results:
             pipeline.validation = code_block.validation + pipeline.validation
             pipeline.test = code_block.test + pipeline.test
-            pipeline.train = code_block.train + pipeline.train
             pipeline.predict = code_block.predict + pipeline.predict
+            # Generated code that mentions cols_has_symbols reports features/predictions via rename_symbol_cols.
+            if "cols_has_symbols" in pipeline.test:
+                pipeline.test = pipeline.test.replace(
+                    '"feature": feature_train.columns',
+                    '"feature": feature_train.rename(columns=rename_symbol_cols).columns',
+                )
+                pipeline.test = pipeline.test.replace(
+                    "prediction.to_csv", "prediction.rename(columns=rename_symbol_cols).to_csv"
+                )
+
+                pipeline.predict = pipeline.predict.replace(
+                    '"feature": feature_train.columns',
+                    '"feature": feature_train.rename(columns=rename_symbol_cols).columns',
+                )
+                pipeline.predict = pipeline.predict.replace(
+                    "prediction.to_csv", "prediction.rename(columns=rename_symbol_cols).to_csv"
+                )
+
+                pipeline.validation = pipeline.validation.replace(
+                    '"feature": feature_train.columns',
+                    '"feature": feature_train.rename(columns=rename_symbol_cols).columns',
+                )
+                pipeline.validation = pipeline.validation.replace(
+                    "prediction.to_csv", "prediction.rename(columns=rename_symbol_cols).to_csv"
+                )
+
+                def replace_targets(match_obj):
+                    return match_obj[0].replace(
+                        "TARGET_COLUMNS", "[rename_symbol_cols.get(v, v) for v in TARGET_COLUMNS]"
+                    )
+
+                pat = r"prediction = pd.DataFrame\(y_prob, columns=.?TARGET_COLUMNS.*, index=feature_test.index\)"
+                pipeline.test = re.sub(pat, replace_targets, pipeline.test)
+                pipeline.predict = re.sub(pat, replace_targets, pipeline.predict)
+                pipeline.validation = re.sub(pat, replace_targets, pipeline.validation)
+
+            pipeline.train = code_block.train + pipeline.train
             result_pipelines.append(pipeline)
 
         logger.info("Executing generated pipelines...")