|
41 | 41 | "! zenml login https://1cf18d95-zenml.cloudinfra.zenml.io \n", |
42 | 42 | "\n", |
43 | 43 | "import IPython\n", |
| 44 | + "\n", |
44 | 45 | "IPython.Application.instance().kernel.do_shutdown(restart=True)" |
45 | 46 | ] |
46 | 47 | }, |
|
73 | 74 | "outputs": [], |
74 | 75 | "source": [ |
75 | 76 | "# Do the imports at the top\n", |
76 | | - "from typing_extensions import Annotated\n", |
77 | | - "from sklearn.datasets import load_breast_cancer\n", |
| 77 | + "from uuid import UUID\n", |
78 | 78 | "\n", |
79 | 79 | "import pandas as pd\n", |
80 | | - "from zenml import step, pipeline, Model, get_step_context\n", |
| 80 | + "from pipelines import feature_engineering, training\n", |
| 81 | + "from sklearn.datasets import load_breast_cancer\n", |
| 82 | + "from steps import data_loader, inference_preprocessor\n", |
| 83 | + "from typing_extensions import Annotated\n", |
| 84 | + "from zenml import Model, get_step_context, pipeline, step\n", |
81 | 85 | "from zenml.client import Client\n", |
82 | 86 | "from zenml.logger import get_logger\n", |
83 | | - "from uuid import UUID\n", |
84 | | - "\n", |
85 | | - "from zenml import pipeline\n", |
86 | | - "\n", |
87 | | - "from steps import (\n", |
88 | | - " data_loader,\n", |
89 | | - " inference_preprocessor\n", |
90 | | - ")\n", |
91 | | - "from pipelines import feature_engineering, training\n", |
92 | 87 | "\n", |
93 | 88 | "logger = get_logger(__name__)\n", |
94 | 89 | "\n", |
|
126 | 121 | "@step\n", |
127 | 122 | "def data_loader_simplified(\n", |
128 | 123 | " random_state: int, is_inference: bool = False, target: str = \"target\"\n", |
129 | | - ") -> Annotated[pd.DataFrame, \"dataset\"]: # We name the dataset \n", |
| 124 | + ") -> Annotated[pd.DataFrame, \"dataset\"]: # We name the dataset\n", |
130 | 125 | " \"\"\"Dataset reader step.\"\"\"\n", |
131 | 126 | " dataset = load_breast_cancer(as_frame=True)\n", |
132 | 127 | " inference_size = int(len(dataset.target) * 0.05)\n", |
133 | 128 | " dataset: pd.DataFrame = dataset.frame\n", |
134 | | - " inference_subset = dataset.sample(inference_size, random_state=random_state)\n", |
| 129 | + " inference_subset = dataset.sample(\n", |
| 130 | + " inference_size, random_state=random_state\n", |
| 131 | + " )\n", |
135 | 132 | " if is_inference:\n", |
136 | 133 | " dataset = inference_subset\n", |
137 | 134 | " dataset.drop(columns=target, inplace=True)\n", |
138 | 135 | " else:\n", |
139 | 136 | " dataset.drop(inference_subset.index, inplace=True)\n", |
140 | 137 | " dataset.reset_index(drop=True, inplace=True)\n", |
141 | 138 | " logger.info(f\"Dataset with {len(dataset)} records loaded!\")\n", |
142 | | - " return dataset\n" |
| 139 | + " return dataset" |
143 | 140 | ] |
144 | 141 | }, |
145 | 142 | { |
|
243 | 240 | "metadata": {}, |
244 | 241 | "outputs": [], |
245 | 242 | "source": [ |
246 | | - "feature_engineering(random_state=42,test_size=0.25)" |
| 243 | + "feature_engineering(random_state=42, test_size=0.25)" |
247 | 244 | ] |
248 | 245 | }, |
249 | 246 | { |
|
340 | 337 | "outputs": [], |
341 | 338 | "source": [ |
342 | 339 | "# Get artifact version from our run\n", |
343 | | - "dataset_trn_artifact_version_via_run = run.steps[\"data_preprocessor\"].outputs[\"dataset_trn\"] \n", |
| 340 | + "dataset_trn_artifact_version_via_run = run.steps[\"data_preprocessor\"].outputs[\n", |
| 341 | + " \"dataset_trn\"\n", |
| 342 | + "]\n", |
344 | 343 | "\n", |
345 | 344 | "# Get latest version from client directly\n", |
346 | 345 | "dataset_trn_artifact_version = client.get_artifact_version(\"dataset_trn\")\n", |
|
359 | 358 | "source": [ |
360 | 359 | "# Fetch the rest of the artifacts\n", |
361 | 360 | "dataset_tst_artifact_version = client.get_artifact_version(\"dataset_tst\")\n", |
362 | | - "preprocessing_pipeline_artifact_version = client.get_artifact_version(\"preprocess_pipeline\")" |
| 361 | + "preprocessing_pipeline_artifact_version = client.get_artifact_version(\n", |
| 362 | + " \"preprocess_pipeline\"\n", |
| 363 | + ")" |
363 | 364 | ] |
364 | 365 | }, |
365 | 366 | { |
|
480 | 481 | "training.with_options(enable_cache=False)(\n", |
481 | 482 | " model_type=\"xgboost\",\n", |
482 | 483 | " train_dataset_id=dataset_trn_artifact_version.id,\n", |
483 | | - " test_dataset_id=dataset_tst_artifact_version.id\n", |
| 484 | + " test_dataset_id=dataset_tst_artifact_version.id,\n", |
484 | 485 | ")\n", |
485 | 486 | "\n", |
486 | 487 | "xgboost_run = client.get_pipeline(\"training\").last_run" |
|
497 | 498 | "sgd_run = training.with_options(enable_cache=False)(\n", |
498 | 499 | " model_type=\"sgd\",\n", |
499 | 500 | " train_dataset_id=dataset_trn_artifact_version.id,\n", |
500 | | - " test_dataset_id=dataset_tst_artifact_version.id\n", |
| 501 | + " test_dataset_id=dataset_tst_artifact_version.id,\n", |
501 | 502 | ")\n", |
502 | 503 | "\n", |
503 | 504 | "sgd_run = client.get_pipeline(\"training\").last_run" |
|
521 | 522 | "outputs": [], |
522 | 523 | "source": [ |
523 | 524 | "# The evaluator returns a float value with the accuracy\n", |
524 | | - "xgboost_run.steps[\"model_evaluator\"].output.load() >= sgd_run.steps[\"model_evaluator\"].output.load()" |
| 525 | + "xgboost_run.steps[\"model_evaluator\"].output.load() >= sgd_run.steps[\n", |
| 526 | + " \"model_evaluator\"\n", |
| 527 | + "].output.load()" |
525 | 528 | ] |
526 | 529 | }, |
527 | 530 | { |
|
579 | 582 | "training_configured(\n", |
580 | 583 | " model_type=\"sgd\",\n", |
581 | 584 | " train_dataset_id=dataset_trn_artifact_version.id,\n", |
582 | | - " test_dataset_id=dataset_tst_artifact_version.id\n", |
| 585 | + " test_dataset_id=dataset_tst_artifact_version.id,\n", |
583 | 586 | ")" |
584 | 587 | ] |
585 | 588 | }, |
|
601 | 604 | "training_configured(\n", |
602 | 605 | " model_type=\"xgboost\",\n", |
603 | 606 | " train_dataset_id=dataset_trn_artifact_version.id,\n", |
604 | | - " test_dataset_id=dataset_tst_artifact_version.id\n", |
| 607 | + " test_dataset_id=dataset_tst_artifact_version.id,\n", |
605 | 608 | ")" |
606 | 609 | ] |
607 | 610 | }, |
|
650 | 653 | "outputs": [], |
651 | 654 | "source": [ |
652 | 655 | "# Let's load the XGBoost version\n", |
653 | | - "xgboost_zenml_model_version = client.list_model_versions(\"breast_cancer_classifier\", tag=\"xgboost\")[-1]\n", |
| 656 | + "xgboost_zenml_model_version = client.list_model_versions(\n", |
| 657 | + " \"breast_cancer_classifier\", tag=\"xgboost\"\n", |
| 658 | + ")[-1]\n", |
654 | 659 | "\n", |
655 | 660 | "# We can now load our classifier directly as well\n", |
656 | | - "xgboost_classifier = xgboost_zenml_model_version.get_artifact(\"breast_cancer_classifier\").load()\n", |
| 661 | + "xgboost_classifier = xgboost_zenml_model_version.get_artifact(\n", |
| 662 | + " \"breast_cancer_classifier\"\n", |
| 663 | + ").load()\n", |
657 | 664 | "\n", |
658 | 665 | "xgboost_classifier" |
659 | 666 | ] |
|
750 | 757 | "outputs": [], |
751 | 758 | "source": [ |
752 | 759 | "@step\n", |
753 | | - "def inference_predict(dataset_inf: pd.DataFrame) -> Annotated[pd.Series, \"predictions\"]:\n", |
| 760 | + "def inference_predict(\n", |
| 761 | + " dataset_inf: pd.DataFrame,\n", |
| 762 | + ") -> Annotated[pd.Series, \"predictions\"]:\n", |
754 | 763 | " \"\"\"Predictions step\"\"\"\n", |
755 | 764 | " # Get the model\n", |
756 | 765 | " model = get_step_context().model\n", |
|
761 | 770 | "\n", |
762 | 771 | " predictions = pd.Series(predictions, name=\"predicted\")\n", |
763 | 772 | "\n", |
764 | | - " return predictions\n" |
| 773 | + " return predictions" |
765 | 774 | ] |
766 | 775 | }, |
767 | 776 | { |
|
788 | 797 | " random_state = 42\n", |
789 | 798 | " target = \"target\"\n", |
790 | 799 | "\n", |
791 | | - " df_inference = data_loader(\n", |
792 | | - " random_state=random_state, is_inference=True\n", |
793 | | - " )\n", |
| 800 | + " df_inference = data_loader(random_state=random_state, is_inference=True)\n", |
794 | 801 | " df_inference = inference_preprocessor(\n", |
795 | 802 | " dataset_inf=df_inference,\n", |
796 | 803 | " # We use the preprocess pipeline from the feature engineering pipeline\n", |
797 | | - " preprocess_pipeline=client.get_artifact_version(name_id_or_prefix=preprocess_pipeline_id),\n", |
| 804 | + " preprocess_pipeline=client.get_artifact_version(\n", |
| 805 | + " name_id_or_prefix=preprocess_pipeline_id\n", |
| 806 | + " ),\n", |
798 | 807 | " target=target,\n", |
799 | 808 | " )\n", |
800 | 809 | " inference_predict(\n", |
801 | 810 | " dataset_inf=df_inference,\n", |
802 | | - " )\n" |
| 811 | + " )" |
803 | 812 | ] |
804 | 813 | }, |
805 | 814 | { |
|
823 | 832 | "# Lets add some metadata to the model to make it identifiable\n", |
824 | 833 | "pipeline_settings[\"model\"] = Model(\n", |
825 | 834 | " name=\"breast_cancer_classifier\",\n", |
826 | | - " version=\"production\", # We can pass in the stage name here!\n", |
| 835 | + " version=\"production\", # We can pass in the stage name here!\n", |
827 | 836 | ")" |
828 | 837 | ] |
829 | 838 | }, |
|
863 | 872 | "outputs": [], |
864 | 873 | "source": [ |
865 | 874 | "# Fetch production model\n", |
866 | | - "production_model_version = client.get_model_version(\"breast_cancer_classifier\", \"production\")\n", |
| 875 | + "production_model_version = client.get_model_version(\n", |
| 876 | + " \"breast_cancer_classifier\", \"production\"\n", |
| 877 | + ")\n", |
867 | 878 | "\n", |
868 | 879 | "# Get the predictions artifact\n", |
869 | 880 | "production_model_version.get_artifact(\"predictions\").load()" |
|
895 | 906 | "outputs": [], |
896 | 907 | "source": [ |
897 | 908 | "from zenml.client import Client\n", |
| 909 | + "\n", |
898 | 910 | "client = Client()" |
899 | 911 | ] |
900 | 912 | }, |
|
905 | 917 | "metadata": {}, |
906 | 918 | "outputs": [], |
907 | 919 | "source": [ |
908 | | - "sgd_model_version = client.list_model_versions(\"breast_cancer_classifier\",tag=\"sgd\")[-1]\n", |
909 | | - "xgboost_model_version = client.list_model_versions(\"breast_cancer_classifier\",tag=\"xgboost\")[-1]\n", |
| 920 | + "sgd_model_version = client.list_model_versions(\n", |
| 921 | + " \"breast_cancer_classifier\", tag=\"sgd\"\n", |
| 922 | + ")[-1]\n", |
| 923 | + "xgboost_model_version = client.list_model_versions(\n", |
| 924 | + " \"breast_cancer_classifier\", tag=\"xgboost\"\n", |
| 925 | + ")[-1]\n", |
910 | 926 | "print(f\"SGD version is staged as `{sgd_model_version.stage}`\")\n", |
911 | 927 | "print(f\"XGBoost version is staged as `{xgboost_model_version.stage}`\")" |
912 | 928 | ] |
|
975 | 991 | "metadata": {}, |
976 | 992 | "outputs": [], |
977 | 993 | "source": [ |
978 | | - "sgd_clf_metadata = sgd_model_version.get_artifact(\"breast_cancer_classifier\").run_metadata\n", |
979 | | - "xgboost_clf_metadata = xgboost_model_version.get_artifact(\"breast_cancer_classifier\").run_metadata\n", |
980 | | - "print(f\"SGD{' (production)' if sgd_model_version.stage == 'production' else ''} metrics: train={sgd_clf_metadata['train_accuracy'].value*100:.2f}% test={sgd_clf_metadata['test_accuracy'].value*100:.2f}%\")\n", |
981 | | - "print(f\"XGBoost{' (production)' if xgboost_model_version.stage == 'production' else ''} metrics: train={xgboost_clf_metadata['train_accuracy'].value*100:.2f}% test={xgboost_clf_metadata['test_accuracy'].value*100:.2f}%\")" |
| 994 | + "sgd_clf_metadata = sgd_model_version.get_artifact(\n", |
| 995 | + " \"breast_cancer_classifier\"\n", |
| 996 | + ").run_metadata\n", |
| 997 | + "xgboost_clf_metadata = xgboost_model_version.get_artifact(\n", |
| 998 | + " \"breast_cancer_classifier\"\n", |
| 999 | + ").run_metadata\n", |
| 1000 | + "print(\n", |
| 1001 | + " f\"SGD{' (production)' if sgd_model_version.stage == 'production' else ''} metrics: train={sgd_clf_metadata['train_accuracy'].value*100:.2f}% test={sgd_clf_metadata['test_accuracy'].value*100:.2f}%\"\n", |
| 1002 | + ")\n", |
| 1003 | + "print(\n", |
| 1004 | + " f\"XGBoost{' (production)' if xgboost_model_version.stage == 'production' else ''} metrics: train={xgboost_clf_metadata['train_accuracy'].value*100:.2f}% test={xgboost_clf_metadata['test_accuracy'].value*100:.2f}%\"\n", |
| 1005 | + ")" |
982 | 1006 | ] |
983 | 1007 | }, |
984 | 1008 | { |
|
996 | 1020 | "metadata": {}, |
997 | 1021 | "outputs": [], |
998 | 1022 | "source": [ |
999 | | - "import seaborn as sns\n", |
1000 | | - "import numpy as np\n", |
1001 | 1023 | "import matplotlib.pyplot as plt\n", |
| 1024 | + "import numpy as np\n", |
| 1025 | + "import seaborn as sns\n", |
1002 | 1026 | "\n", |
1003 | | - "def plot_confusion_matrix(metadata_pointer, tp: str,ax):\n", |
1004 | | - " confusion_matrix = np.array(metadata_pointer[\"confusion_matrix\"].value, dtype=float).reshape((2,2))\n", |
| 1027 | + "\n", |
| 1028 | + "def plot_confusion_matrix(metadata_pointer, tp: str, ax):\n", |
| 1029 | + " confusion_matrix = np.array(\n", |
| 1030 | + " metadata_pointer[\"confusion_matrix\"].value, dtype=float\n", |
| 1031 | + " ).reshape((2, 2))\n", |
1005 | 1032 | " confusion_matrix /= np.sum(confusion_matrix)\n", |
1006 | | - " sns.heatmap(confusion_matrix, annot=True,fmt='.2%',cmap=\"coolwarm\",ax=ax)\n", |
| 1033 | + " sns.heatmap(\n", |
| 1034 | + " confusion_matrix, annot=True, fmt=\".2%\", cmap=\"coolwarm\", ax=ax\n", |
| 1035 | + " )\n", |
1007 | 1036 | " ax.set_title(f\"{tp} confusion matrix\")\n", |
1008 | 1037 | " ax.set_ylabel(\"Ground Label\")\n", |
1009 | 1038 | " ax.set_xlabel(\"Predicted Label\")\n", |
1010 | 1039 | "\n", |
1011 | | - "fig, ax = plt.subplots(1,2,figsize=(15,4))\n", |
1012 | | - "plot_confusion_matrix(sgd_clf_metadata, \"SGD\",ax[0])\n", |
1013 | | - "plot_confusion_matrix(xgboost_clf_metadata, \"RF\",ax[1])" |
| 1040 | + "\n", |
| 1041 | + "fig, ax = plt.subplots(1, 2, figsize=(15, 4))\n", |
| 1042 | + "plot_confusion_matrix(sgd_clf_metadata, \"SGD\", ax[0])\n", |
| 1043 | + "plot_confusion_matrix(xgboost_clf_metadata, \"RF\", ax[1])" |
1014 | 1044 | ] |
1015 | 1045 | }, |
1016 | 1046 | { |
|
1052 | 1082 | "for artifact_name, versions in sgd_model_version.data_artifacts.items():\n", |
1053 | 1083 | " if versions:\n", |
1054 | 1084 | " print(f\"Existing version of `{artifact_name}`:\")\n", |
1055 | | - " for version_name, artifact_ in versions.items():\n", |
| 1085 | + " for version_name, artifact_ in versions.items():\n", |
1056 | 1086 | " print(version_name, artifact_.data_type.attribute)" |
1057 | 1087 | ] |
1058 | 1088 | }, |
|
0 commit comments